Coverage for src / qdrant_loader_core / llm / providers / bedrock.py: 84%
115 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:34 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:34 +0000
1from __future__ import annotations
3import asyncio
4import json
5from datetime import UTC, datetime
6from typing import Any
8try:
9 import boto3
10except ImportError:
11 boto3 = None # type: ignore[assignment]
13from ...logging import LoggingConfig
14from ..errors import (
15 InvalidRequestError,
16 LLMError,
17 ServerError,
18)
19from ..settings import LLMSettings
20from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter
22# Re-exported exception types are intentionally kept here.
23#
24# They are used by:
25# 1. Runtime exception mapping in _map_bedrock_exception()
26# 2. Import fallback behavior when boto3 / botocore is unavailable
27# 3. Unit tests that verify the module still exposes these symbols
28#
29# Do not remove these imports even if they appear unused, as they are
30# part of the module contract and support graceful degradation.
31from .bedrock_utils import (
32 BedrockTokenizer,
33 BotoCoreError, # re-export: keep for import-fallback contract/tests
34 ClientError, # re-export: keep for import-fallback contract/tests
35 EndpointConnectionError, # re-export: keep for import-fallback contract/tests
36 NoCredentialsError, # re-export: keep for import-fallback contract/tests
37 _extract_embeddings,
38 _map_bedrock_exception,
39)
# Names listed here are the exception types re-exported from bedrock_utils;
# declaring them in __all__ marks those imports as intentional.
__all__ = [
    "BotoCoreError",
    "ClientError",
    "EndpointConnectionError",
    "NoCredentialsError",
]

# Module-level structured logger shared by the classes in this module.
logger = LoggingConfig.get_logger(__name__)
# Embedding dimension assumed for each model id when the settings do not
# specify a vector size explicitly.
DEFAULT_VECTOR_SIZES: dict[str, int] = {
    "amazon.titan-embed-text-v2:0": 1024,
    "amazon.titan-embed-text-v1": 1536,
}
class BedrockEmbeddings(EmbeddingsClient):
    """Embeddings client for Bedrock Titan models, invoking one API call per input.

    Each input text is sent as its own ``invoke_model`` request. The blocking
    client call runs in a worker thread, and the number of requests in flight
    is capped by a semaphore sized from ``concurrency``.
    """

    # Upper bound on the number of texts accepted by a single embed() call.
    MAX_BATCH_SIZE = 1000

    def __init__(
        self,
        client: Any,
        model_id: str,
        expected_vector_size: int | None = None,
        *,
        provisioned_throughput_arn: str | None = None,
        provider_label: str = "bedrock",
        concurrency: int = 8,
    ) -> None:
        """Store the client configuration and create the concurrency limiter.

        Args:
            client: Object exposing ``invoke_model(**kwargs)``; may be ``None``,
                in which case ``embed()`` raises ``NotImplementedError``.
            model_id: Bedrock model identifier used for requests and logging.
            expected_vector_size: If set, every returned vector must have this
                length.
            provisioned_throughput_arn: If set, sent as ``modelId`` instead of
                ``model_id``.
            provider_label: Value logged in the ``provider`` field.
            concurrency: Maximum number of simultaneous requests.

        Raises:
            InvalidRequestError: If ``concurrency`` is less than 1.
        """
        self._client = client
        self._model_id = model_id
        self._expected_vector_size = expected_vector_size
        self._provisioned_throughput_arn = provisioned_throughput_arn
        self._provider_label = provider_label
        self._concurrency = concurrency
        if self._concurrency < 1:
            raise InvalidRequestError(
                "Bedrock embeddings 'concurrency' must be a positive integer"
            )
        self._semaphore = asyncio.Semaphore(self._concurrency)

    def _build_invoke_kwargs(self, text: str) -> dict[str, Any]:
        """Build the keyword arguments for one ``invoke_model`` call."""
        payload_body: dict[str, Any] = {"inputText": text}

        # Only Titan v2 model ids get an explicit "dimensions" field, and only
        # when the requested size differs from that model's default.
        if (
            self._model_id.startswith("amazon.titan-embed-text-v2")
            and self._expected_vector_size is not None
        ):
            if (
                self._expected_vector_size
                != DEFAULT_VECTOR_SIZES["amazon.titan-embed-text-v2:0"]
            ):
                payload_body["dimensions"] = self._expected_vector_size

        invoke_kwargs: dict[str, Any] = {
            # A provisioned-throughput ARN, when configured, replaces the model id.
            "modelId": self._provisioned_throughput_arn or self._model_id,
            "contentType": "application/json",
            "accept": "application/json",
            "body": json.dumps(payload_body),
        }

        return invoke_kwargs

    def _read_response_body(self, response: Any) -> str:
        """Return the response body as text.

        The body is looked up as a ``"body"`` key (dict responses) or a
        ``body`` attribute (other objects). A body with a ``read`` method is
        read first; bytes are decoded as UTF-8.

        Raises:
            ServerError: If the body is missing or is neither bytes nor str.
        """
        body_data = (
            response.get("body")
            if isinstance(response, dict)
            else getattr(response, "body", None)
        )

        if body_data is None:
            raise ServerError("Bedrock response body missing")

        if hasattr(body_data, "read"):
            body_bytes = body_data.read()
        else:
            body_bytes = body_data

        if isinstance(body_bytes, bytes):
            return body_bytes.decode("utf-8")

        if isinstance(body_bytes, str):
            return body_bytes

        raise ServerError("Bedrock response body is not bytes or string")

    def _parse_single_embedding(
        self,
        raw_text: str,
    ) -> list[float]:
        """Parse a JSON response body and return its single embedding vector.

        Raises:
            ServerError: If the payload does not contain exactly one embedding.
        """
        payload = json.loads(raw_text)
        embeddings = _extract_embeddings(payload)

        if len(embeddings) != 1:
            raise ServerError(
                "Bedrock single request must return exactly one embedding"
            )

        return embeddings[0]

    def _validate_vector(
        self,
        vector: list[float],
    ) -> None:
        """Check the vector length against the expected size, if one is set.

        Raises:
            ServerError: If the length differs from ``expected_vector_size``.
        """
        if (
            self._expected_vector_size is not None
            and len(vector) != self._expected_vector_size
        ):
            raise ServerError(
                f"Bedrock returned embedding vector with unexpected dimension: "
                f"expected {self._expected_vector_size}, got {len(vector)}"
            )

    async def _invoke_single(
        self,
        text: str,
    ) -> list[float]:
        """Embed one text with a single ``invoke_model`` request.

        Raises:
            LLMError: Errors already of this type are re-raised unchanged; any
                other exception is converted by ``_map_bedrock_exception`` and
                raised with the original as its cause.
        """
        invoke_kwargs = self._build_invoke_kwargs(text)
        started = datetime.now(UTC)

        try:
            # invoke_model is a blocking call, so it runs in a worker thread.
            response = await asyncio.to_thread(
                self._client.invoke_model,
                **invoke_kwargs,
            )

            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)

            logger.info(
                "LLM request",
                provider=self._provider_label,
                operation="embeddings",
                model=self._model_id,
                latency_ms=duration_ms,
                inputs=1,
            )

            raw_text = self._read_response_body(response)
            vector = self._parse_single_embedding(raw_text)
            self._validate_vector(vector)

            return vector

        except LLMError:
            # Already a domain error (e.g. raised by the validators above).
            raise

        except Exception as exc:
            mapped = _map_bedrock_exception(exc)

            logger.warning(
                "LLM error",
                provider=self._provider_label,
                operation="embeddings",
                model=self._model_id,
                error=type(exc).__name__,
            )

            raise mapped from exc

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Embed every text in ``inputs`` and return the vectors in input order.

        Requests run concurrently, limited by the instance semaphore. If any
        request fails, the requests that have not finished are cancelled
        before the error is re-raised, so a failed batch does not keep issuing
        API calls in the background.

        Raises:
            NotImplementedError: If no client is available.
            InvalidRequestError: If ``inputs`` exceeds ``MAX_BATCH_SIZE``.
            LLMError: The first error raised by an individual request.
        """
        if self._client is None:
            raise NotImplementedError("Bedrock client not available")

        if not inputs:
            return []

        if len(inputs) > self.MAX_BATCH_SIZE:
            raise InvalidRequestError(
                f"Bedrock embedding batch size cannot exceed {self.MAX_BATCH_SIZE}"
            )

        async def _one(text: str) -> list[float]:
            async with self._semaphore:
                return await self._invoke_single(text)

        tasks = [asyncio.create_task(_one(text)) for text in inputs]
        try:
            return await asyncio.gather(*tasks)
        except Exception:
            # gather() does not cancel sibling awaitables when one fails; do it
            # explicitly so queued requests are never started. A call already
            # running in a worker thread is not interrupted, only abandoned.
            for task in tasks:
                task.cancel()
            # Drain the cancelled tasks so none is left pending.
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
class _BedrockChat(ChatClient):
    """Placeholder chat client: chat is not implemented for Bedrock here."""

    async def chat(
        self,
        messages: list[dict[str, Any]],
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Always raise ``NotImplementedError``; no request is made."""
        raise NotImplementedError("Bedrock chat is not implemented")
class BedrockProvider(LLMProvider):
    """LLM provider wrapper for AWS Bedrock Titan embedding models.

    Options are read from ``settings.provider_options``: ``model_id``,
    ``aws_region``, ``provisioned_throughput_arn`` and ``concurrency``.
    The model id falls back to ``settings.models["embeddings"]``.
    """

    # Model ids accepted by the constructor.
    SUPPORTED_MODELS = frozenset(
        (
            "amazon.titan-embed-text-v2:0",
            "amazon.titan-embed-text-v1",
            # "cohere.embed-english-v3",
        )
    )

    def __init__(
        self,
        settings: LLMSettings,
        *,
        client: Any = None,
    ):
        """Validate the configuration and resolve the Bedrock runtime client.

        Args:
            settings: LLM settings supplying provider options, model names and
                the embeddings vector size.
            client: Pre-built client to use instead of creating one via boto3.

        Raises:
            InvalidRequestError: If ``concurrency`` is not a positive integer,
                if no model id is configured, or if the model id is not in
                ``SUPPORTED_MODELS``.
        """
        self._settings = settings
        options = settings.provider_options or {}

        configured_model = options.get("model_id") or settings.models.get(
            "embeddings"
        )
        self._model_id = "" if configured_model is None else str(configured_model)
        self._aws_region = options.get("aws_region")
        self._provisioned_throughput_arn = options.get("provisioned_throughput_arn")

        # Concurrency is checked first, then the model id.
        try:
            self._concurrency = int(options.get("concurrency", 8))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Bedrock provider 'concurrency' must be a positive integer"
            ) from exc
        if self._concurrency < 1:
            raise InvalidRequestError(
                "Bedrock provider 'concurrency' must be a positive integer"
            )

        if not self._model_id:
            raise InvalidRequestError(
                "Bedrock provider requires 'model_id' in "
                "llm.provider_options or llm.models.embeddings"
            )
        if self._model_id not in self.SUPPORTED_MODELS:
            supported = sorted(self.SUPPORTED_MODELS)
            raise InvalidRequestError(
                f"Bedrock provider only supports {supported}, got {self._model_id!r}"
            )

        # An explicit vector size wins; otherwise use the per-model default.
        if settings.embeddings.vector_size is not None:
            self._vector_size = settings.embeddings.vector_size
        else:
            self._vector_size = DEFAULT_VECTOR_SIZES.get(self._model_id, 1024)

        # Prefer an injected client; otherwise build one when boto3 imported
        # successfully, and fall back to None when it did not.
        if client is not None:
            self._client = client
        elif boto3 is not None:
            self._client = boto3.client(
                "bedrock-runtime",
                region_name=self._aws_region,
            )
        else:
            self._client = None

    def embeddings(self) -> EmbeddingsClient:
        """Return a new ``BedrockEmbeddings`` client built from this provider's configuration."""
        embeddings_client = BedrockEmbeddings(
            self._client,
            self._model_id,
            self._vector_size,
            provisioned_throughput_arn=self._provisioned_throughput_arn,
            provider_label="bedrock",
            concurrency=self._concurrency,
        )
        return embeddings_client

    def chat(self) -> ChatClient:
        """Always raise ``NotImplementedError``; chat is not supported."""
        raise NotImplementedError("Bedrock provider does not support chat()")

    def tokenizer(self) -> TokenCounter:
        """Return a new ``BedrockTokenizer``."""
        return BedrockTokenizer()