Coverage for src / qdrant_loader_core / llm / providers / gemini.py: 0%
189 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:34 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:34 +0000
1from __future__ import annotations
3from datetime import UTC, datetime
4from typing import Any
6try:
7 from google import genai # type: ignore
8 from google.genai import types as genai_types # type: ignore
10 try:
11 from google.genai import errors as genai_errors # type: ignore
12 except Exception: # pragma: no cover - optional dependency surface
13 genai_errors = None # type: ignore
14except Exception: # pragma: no cover - optional dependency at this phase
15 genai = None # type: ignore
16 genai_types = None # type: ignore
17 genai_errors = None # type: ignore
19from ...logging import LoggingConfig
20from ..errors import (
21 AuthError,
22 InvalidRequestError,
23 LLMError,
24 RateLimitedError,
25 ServerError,
26)
27from ..errors import TimeoutError as LLMTimeoutError
28from ..settings import LLMSettings
29from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter
# Module-level logger from the project's LoggingConfig. Calls in this module
# pass structured keyword fields (provider=..., model=..., latency_ms=...) and
# wrap each logging call in try/except so logging failures never break a request.
logger = LoggingConfig.get_logger(__name__)
34def _map_gemini_exception(exc: Exception) -> LLMError:
35 status_code: int | None = None
36 if genai_errors is not None:
37 try:
38 if isinstance(exc, genai_errors.APIError): # type: ignore[attr-defined]
39 status_code = getattr(exc, "code", None) or getattr(
40 exc, "status_code", None
41 )
42 except Exception:
43 pass
45 if status_code is None:
46 status_code = getattr(exc, "status_code", None) or getattr(exc, "code", None)
48 if isinstance(status_code, int):
49 if status_code == 408 or status_code == 504:
50 return LLMTimeoutError(str(exc))
51 if status_code == 429:
52 return RateLimitedError(str(exc))
53 if status_code in (401, 403):
54 return AuthError(str(exc))
55 if 400 <= status_code < 500:
56 return InvalidRequestError(str(exc))
57 if status_code >= 500:
58 return ServerError(str(exc))
60 if isinstance(exc, TimeoutError):
61 return LLMTimeoutError(str(exc))
63 return ServerError(str(exc))
class _GeminiTokenCounter(TokenCounter):
    """Token counter backed by the Gemini ``models.count_tokens`` endpoint.

    Falls back to the character length of the text when no client is
    configured, the remote call fails, or the response carries no integer
    ``total_tokens``.
    """

    def __init__(self, client: Any, model: str):
        self._client = client
        self._model = model

    def count(self, text: str) -> int:
        """Return the number of tokens in *text*.

        Empty (or ``None``) text counts as 0 without contacting the API.
        Otherwise returns the remote count, or ``len(text)`` as a fallback.
        """
        # Nothing to count: skip the network round trip entirely.
        if not text:
            return 0
        if self._client is None:
            return len(text)
        try:
            response = self._client.models.count_tokens(
                model=self._model, contents=text
            )
        except Exception as exc:
            # Best effort: a failed count must not break callers, but record
            # it at debug level instead of discarding it silently.
            try:
                logger.debug(
                    "Gemini token count failed; using character length",
                    model=self._model,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            return len(text)
        total = getattr(response, "total_tokens", None)
        return total if isinstance(total, int) else len(text)
86def _messages_to_contents(
87 messages: list[dict[str, Any]],
88) -> tuple[str | None, list[Any]]:
89 """Convert OpenAI-style messages to (system_instruction, contents) for Gemini."""
90 system_parts: list[str] = []
91 contents: list[Any] = []
92 for msg in messages:
93 role = (msg.get("role") or "").lower()
94 content = msg.get("content")
95 if content is None:
96 continue
97 if not isinstance(content, str):
98 content = str(content)
99 if role == "system":
100 system_parts.append(content)
101 continue
102 gemini_role = "model" if role == "assistant" else "user"
103 if genai_types is not None:
104 contents.append(
105 genai_types.Content(
106 role=gemini_role,
107 parts=[genai_types.Part.from_text(text=content)],
108 )
109 )
110 else:
111 contents.append({"role": gemini_role, "parts": [{"text": content}]})
113 system_instruction = "\n".join(system_parts) if system_parts else None
114 return system_instruction, contents
class GeminiEmbeddings(EmbeddingsClient):
    """Embeddings client backed by Gemini ``models.embed_content``.

    Each input string is embedded separately, and vectors are returned in
    input order. Inputs are sent in batches of at most ``batch_size`` texts,
    one SDK call per batch.
    """

    def __init__(
        self,
        client: Any,
        model: str,
        *,
        provider_label: str = "gemini",
        output_dimensionality: int | None = None,
        batch_size: int = 100,
    ):
        """Create the client.

        Args:
            client: SDK client exposing ``models.embed_content`` (falsy when
                unavailable).
            model: Embedding model id.
            provider_label: Value used for the ``provider`` log field.
            output_dimensionality: Optional requested vector size, forwarded
                via ``EmbedContentConfig``.
            batch_size: Maximum number of inputs per ``embed_content`` call.
                NOTE(review): 100 is assumed to match the Gemini API's
                per-request limit for batch embeddings; confirm for Vertex AI.

        Raises:
            ValueError: ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self._client = client
        self._model = model
        self._provider_label = provider_label
        self._output_dimensionality = output_dimensionality
        self._batch_size = batch_size

    def _build_config(self) -> Any:
        """Return an EmbedContentConfig, or None when no dimensionality is set
        or the SDK types are unavailable."""
        if self._output_dimensionality is None or genai_types is None:
            return None
        return genai_types.EmbedContentConfig(
            output_dimensionality=self._output_dimensionality
        )

    def _wrap_inputs(self, inputs: list[str]) -> list[Any]:
        """Wrap each input in its own Content so each yields its own vector."""
        # gemini-embedding-2 aggregates a list of raw strings into ONE embedding.
        # Wrapping each input in its own Content forces one embedding per input.
        if genai_types is None:
            return list(inputs)
        return [
            genai_types.Content(parts=[genai_types.Part.from_text(text=text)])
            for text in inputs
        ]

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Embed *inputs*, returning one vector per input in the same order.

        Raises:
            NotImplementedError: no Gemini client is available.
            InvalidRequestError: no embeddings model is configured.
            ServerError: a response did not contain exactly one embedding per
                input in its batch.
            LLMError: other SDK failures, mapped by ``_map_gemini_exception``.
        """
        if not self._client:
            raise NotImplementedError("Gemini client not available")
        if not self._model:
            raise InvalidRequestError(
                "Gemini embeddings model is not configured. "
                "Set global.llm.models.embeddings in your config "
                "(e.g. 'gemini-embedding-2' or 'gemini-embedding-001')."
            )
        if not inputs:
            return []
        import asyncio

        config = self._build_config()
        batch_count = 0
        vectors: list[list[float]] = []
        started = datetime.now(UTC)
        try:
            for offset in range(0, len(inputs), self._batch_size):
                batch = inputs[offset : offset + self._batch_size]
                batch_count += 1
                call_kwargs: dict[str, Any] = {
                    "model": self._model,
                    "contents": self._wrap_inputs(batch),
                }
                if config is not None:
                    call_kwargs["config"] = config
                # The SDK call is synchronous; run it off the event loop.
                response = await asyncio.to_thread(
                    self._client.models.embed_content, **call_kwargs
                )
                embeddings = getattr(response, "embeddings", None) or []
                batch_vectors = [
                    list(getattr(item, "values", []) or []) for item in embeddings
                ]
                if len(batch_vectors) != len(batch):
                    raise ServerError(
                        f"Gemini embed_content returned {len(batch_vectors)} embeddings "
                        f"for {len(batch)} inputs (expected 1:1). "
                        f"Verify the model id {self._model!r} supports per-input embeddings."
                    )
                vectors.extend(batch_vectors)
        except Exception as exc:
            # Errors raised above are already LLMErrors; keep them as-is
            # instead of re-wrapping them through the provider mapping.
            mapped = exc if isinstance(exc, LLMError) else _map_gemini_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            if mapped is exc:
                raise
            raise mapped from exc

        duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
        try:
            logger.info(
                "LLM request",
                provider=self._provider_label,
                operation="embeddings",
                model=self._model,
                inputs=len(inputs),
                batches=batch_count,
                latency_ms=duration_ms,
            )
        except Exception:
            pass
        return vectors
class GeminiChat(ChatClient):
    """Chat client that sends OpenAI-style messages to Gemini ``generate_content``.

    ``chat`` returns a dict with ``text``, ``raw`` (the SDK response),
    ``usage`` (prompt/completion/total token counts, or ``None``) and
    ``model``.
    """

    # OpenAI-style keyword -> GenerateContentConfig field name.
    _CONFIG_KEY_MAP: tuple[tuple[str, str], ...] = (
        ("temperature", "temperature"),
        ("top_p", "top_p"),
        ("max_tokens", "max_output_tokens"),
        ("stop", "stop_sequences"),
        ("seed", "seed"),
    )

    def __init__(
        self,
        client: Any,
        model: str,
        *,
        provider_label: str = "gemini",
    ):
        self._client = client
        self._model = model
        self._provider_label = provider_label

    @classmethod
    def _build_config_kwargs(
        cls, system_instruction: str | None, options: dict[str, Any]
    ) -> dict[str, Any]:
        """Map chat keyword options onto GenerateContentConfig field names.

        Options that are absent or ``None`` are omitted.
        """
        config_kwargs: dict[str, Any] = {}
        if system_instruction:
            config_kwargs["system_instruction"] = system_instruction
        for src_key, dst_key in cls._CONFIG_KEY_MAP:
            value = options.get(src_key)
            if value is None:
                continue
            if dst_key == "stop_sequences" and isinstance(value, str):
                # OpenAI accepts a single stop string; Gemini's
                # stop_sequences field takes a list of strings.
                value = [value]
            config_kwargs[dst_key] = value
        return config_kwargs

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        """Run one chat completion.

        Args:
            messages: OpenAI-style ``{"role", "content"}`` dicts. System
                messages become Gemini's system instruction.
            **kwargs: ``model`` overrides the configured model (``None``
                keeps the default). ``temperature``, ``top_p``,
                ``max_tokens``, ``stop`` and ``seed`` are forwarded to the
                generation config. Other keys are ignored.

        Raises:
            NotImplementedError: no Gemini client is available.
            InvalidRequestError: no chat model is configured.
            LLMError: the SDK call failed (mapped by ``_map_gemini_exception``).
        """
        if not self._client:
            raise NotImplementedError("Gemini client not available")

        # An explicit model=None falls back to the configured default.
        model_name = kwargs.pop("model", None) or self._model
        if not model_name:
            raise InvalidRequestError(
                "Gemini chat model is not configured. "
                "Set global.llm.models.chat in your config "
                "(e.g. 'gemini-2.0-flash')."
            )
        system_instruction, contents = _messages_to_contents(messages)

        config_kwargs = self._build_config_kwargs(system_instruction, kwargs)
        config = None
        if config_kwargs and genai_types is not None:
            config = genai_types.GenerateContentConfig(**config_kwargs)
        import asyncio

        started = datetime.now(UTC)
        try:
            call_kwargs: dict[str, Any] = {
                "model": model_name,
                "contents": contents,
            }
            if config is not None:
                call_kwargs["config"] = config

            # The SDK call is synchronous; run it off the event loop.
            response = await asyncio.to_thread(
                self._client.models.generate_content, **call_kwargs
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    messages=len(messages),
                    latency_ms=duration_ms,
                )
            except Exception:
                pass

            text = getattr(response, "text", "") or ""

            usage_meta = getattr(response, "usage_metadata", None)
            normalized_usage = None
            if usage_meta is not None:
                normalized_usage = {
                    "prompt_tokens": getattr(usage_meta, "prompt_token_count", None),
                    "completion_tokens": getattr(
                        usage_meta, "candidates_token_count", None
                    ),
                    "total_tokens": getattr(usage_meta, "total_token_count", None),
                }

            return {
                "text": text,
                "raw": response,
                "usage": normalized_usage,
                "model": getattr(response, "model_version", model_name) or model_name,
            }
        except Exception as exc:
            mapped = _map_gemini_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            # Chain explicitly so the SDK error is reported as the direct cause.
            raise mapped from exc
class GeminiProvider(LLMProvider):
    """LLMProvider producing Gemini chat, embeddings and token-count clients.

    One ``genai.Client`` is created at construction time and handed to every
    client this provider returns. It is ``None`` when the google-genai SDK
    could not be imported.
    """

    def __init__(self, settings: LLMSettings):
        self._settings = settings
        if genai is not None:
            self._client = genai.Client(**self._client_kwargs(settings))
        else:
            self._client = None

    @staticmethod
    def _client_kwargs(settings: LLMSettings) -> dict[str, Any]:
        """Translate settings into keyword arguments for ``genai.Client``.

        Includes ``api_key`` when set. When ``provider_options`` requests
        Vertex AI, adds ``vertexai=True`` and any truthy ``project`` /
        ``location`` values.
        """
        client_kwargs: dict[str, Any] = {}
        if settings.api_key:
            client_kwargs["api_key"] = settings.api_key
        options = settings.provider_options or {}
        if options.get("vertexai"):
            client_kwargs["vertexai"] = True
            for option_name in ("project", "location"):
                if options.get(option_name):
                    client_kwargs[option_name] = options[option_name]
        return client_kwargs

    def embeddings(self) -> EmbeddingsClient:
        """Return an embeddings client for the configured embeddings model."""
        settings = self._settings
        return GeminiEmbeddings(
            self._client,
            settings.models.get("embeddings", ""),
            provider_label="gemini",
            output_dimensionality=settings.embeddings.vector_size,
        )

    def chat(self) -> ChatClient:
        """Return a chat client for the configured chat model."""
        return GeminiChat(
            self._client, self._settings.models.get("chat", ""), provider_label="gemini"
        )

    def tokenizer(self) -> TokenCounter:
        """Return a token counter keyed on the chat model (else the embeddings model)."""
        configured = self._settings.models
        counting_model = configured.get("chat") or configured.get("embeddings", "")
        return _GeminiTokenCounter(self._client, counting_model)