Coverage for src / qdrant_loader_core / llm / providers / bedrock.py: 84%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:34 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import json 

5from datetime import UTC, datetime 

6from typing import Any 

7 

8try: 

9 import boto3 

10except ImportError: 

11 boto3 = None # type: ignore[assignment] 

12 

13from ...logging import LoggingConfig 

14from ..errors import ( 

15 InvalidRequestError, 

16 LLMError, 

17 ServerError, 

18) 

19from ..settings import LLMSettings 

20from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter 

21 

22# Re-exported exception types are intentionally kept here. 

23# 

24# They are used by: 

25# 1. Runtime exception mapping in _map_bedrock_exception() 

26# 2. Import fallback behavior when boto3 / botocore is unavailable 

27# 3. Unit tests that verify the module still exposes these symbols 

28# 

29# Do not remove these imports even if they appear unused, as they are 

30# part of the module contract and support graceful degradation. 

31from .bedrock_utils import ( 

32 BedrockTokenizer, 

33 BotoCoreError, # re-export: keep for import-fallback contract/tests 

34 ClientError, # re-export: keep for import-fallback contract/tests 

35 EndpointConnectionError, # re-export: keep for import-fallback contract/tests 

36 NoCredentialsError, # re-export: keep for import-fallback contract/tests 

37 _extract_embeddings, 

38 _map_bedrock_exception, 

39) 

40 

# Only the exception types re-exported from .bedrock_utils are listed here,
# which keeps those otherwise-unused imports part of the module's explicit API.
__all__ = [
    "BotoCoreError",
    "ClientError",
    "EndpointConnectionError",
    "NoCredentialsError",
]

# Module-level logger obtained through the project's LoggingConfig.
logger = LoggingConfig.get_logger(__name__)

# Embedding vector size per Bedrock model ID. Used as the fallback when the
# settings carry no explicit vector size, and to decide whether a Titan v2
# request needs an explicit "dimensions" field.
DEFAULT_VECTOR_SIZES: dict[str, int] = {
    "amazon.titan-embed-text-v2:0": 1024,
    "amazon.titan-embed-text-v1": 1536,
}

54 

55 

class BedrockEmbeddings(EmbeddingsClient):
    """Embeddings client for Bedrock Titan models, invoking one API call per input.

    Every input text is sent as its own ``invoke_model`` request. The blocking
    client calls run in worker threads via ``asyncio.to_thread`` and at most
    ``concurrency`` requests issued through this instance are in flight at once.
    """

    # Upper bound on the number of inputs accepted by a single embed() call.
    MAX_BATCH_SIZE = 1000

    def __init__(
        self,
        client: Any,
        model_id: str,
        expected_vector_size: int | None = None,
        *,
        provisioned_throughput_arn: str | None = None,
        provider_label: str = "bedrock",
        concurrency: int = 8,
    ):
        """Store the client configuration and create the concurrency limiter.

        Args:
            client: Object exposing ``invoke_model(**kwargs)``. May be ``None``,
                in which case ``embed`` raises ``NotImplementedError``.
            model_id: Bedrock model identifier used for requests and logging.
            expected_vector_size: When set, every returned vector must have
                exactly this length.
            provisioned_throughput_arn: When set, sent as ``modelId`` instead
                of ``model_id``.
            provider_label: Value of the ``provider`` field in log records.
            concurrency: Maximum number of simultaneous requests.

        Raises:
            InvalidRequestError: If ``concurrency`` is smaller than 1.
        """
        self._client = client
        self._model_id = model_id
        self._expected_vector_size = expected_vector_size
        self._provisioned_throughput_arn = provisioned_throughput_arn
        self._provider_label = provider_label
        self._concurrency = concurrency
        if self._concurrency < 1:
            raise InvalidRequestError(
                "Bedrock embeddings 'concurrency' must be a positive integer"
            )
        self._semaphore = asyncio.Semaphore(self._concurrency)

    def _build_invoke_kwargs(self, text: str) -> dict[str, Any]:
        """Return the keyword arguments for one ``invoke_model`` call.

        The JSON body always carries ``inputText``. ``dimensions`` is added
        only for Titan v2 model IDs when an expected vector size is configured
        and differs from that model's default size.
        """
        payload_body: dict[str, Any] = {"inputText": text}

        v2_default_size = DEFAULT_VECTOR_SIZES["amazon.titan-embed-text-v2:0"]
        if (
            self._model_id.startswith("amazon.titan-embed-text-v2")
            and self._expected_vector_size is not None
            and self._expected_vector_size != v2_default_size
        ):
            payload_body["dimensions"] = self._expected_vector_size

        return {
            # A provisioned-throughput ARN, when configured, replaces the model ID.
            "modelId": self._provisioned_throughput_arn or self._model_id,
            "contentType": "application/json",
            "accept": "application/json",
            "body": json.dumps(payload_body),
        }

    def _read_response_body(self, response: Any) -> str:
        """Extract the response body as text.

        ``response`` may be a mapping with a ``"body"`` key or an object with a
        ``body`` attribute. The body may be a stream-like object with
        ``read()``, or ``bytes``/``bytearray``/``str`` directly.

        This method is synchronous and ``read()`` may block, so async callers
        should run it in a worker thread.

        Raises:
            ServerError: If the body is missing or is not bytes-like or text.
        """
        body_data = (
            response.get("body")
            if isinstance(response, dict)
            else getattr(response, "body", None)
        )

        if body_data is None:
            raise ServerError("Bedrock response body missing")

        raw = body_data.read() if hasattr(body_data, "read") else body_data

        if isinstance(raw, (bytes, bytearray)):
            return raw.decode("utf-8")

        if isinstance(raw, str):
            return raw

        raise ServerError("Bedrock response body is not bytes or string")

    def _parse_single_embedding(
        self,
        raw_text: str,
    ) -> list[float]:
        """Decode the JSON response text and return its single embedding.

        Raises:
            ServerError: If the payload does not contain exactly one embedding.
            json.JSONDecodeError: If ``raw_text`` is not valid JSON.
        """
        payload = json.loads(raw_text)
        embeddings = _extract_embeddings(payload)

        if len(embeddings) != 1:
            raise ServerError(
                "Bedrock single request must return exactly one embedding"
            )

        return embeddings[0]

    def _validate_vector(
        self,
        vector: list[float],
    ) -> None:
        """Raise ``ServerError`` if ``vector`` has an unexpected length.

        No check is performed when no expected vector size is configured.
        """
        if (
            self._expected_vector_size is not None
            and len(vector) != self._expected_vector_size
        ):
            raise ServerError(
                f"Bedrock returned embedding vector with unexpected dimension: "
                f"expected {self._expected_vector_size}, got {len(vector)}"
            )

    async def _invoke_single(
        self,
        text: str,
    ) -> list[float]:
        """Embed one text with a single ``invoke_model`` request.

        Raises:
            LLMError: Re-raised unchanged when raised by this client's own
                validation; any other exception is converted with
                ``_map_bedrock_exception`` and chained to the original.
        """
        invoke_kwargs = self._build_invoke_kwargs(text)
        started = datetime.now(UTC)

        try:
            response = await asyncio.to_thread(
                self._client.invoke_model,
                **invoke_kwargs,
            )

            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)

            logger.info(
                "LLM request",
                provider=self._provider_label,
                operation="embeddings",
                model=self._model_id,
                latency_ms=duration_ms,
                inputs=1,
            )

            # boto3 hands back the body as a stream whose read() performs
            # blocking I/O, so it is read in a worker thread as well instead
            # of stalling the event loop.
            raw_text = await asyncio.to_thread(self._read_response_body, response)
            vector = self._parse_single_embedding(raw_text)
            self._validate_vector(vector)

            return vector

        except LLMError:
            raise

        except Exception as exc:
            mapped = _map_bedrock_exception(exc)

            logger.warning(
                "LLM error",
                provider=self._provider_label,
                operation="embeddings",
                model=self._model_id,
                error=type(exc).__name__,
            )

            raise mapped from exc

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Embed every text in ``inputs`` and return the vectors in input order.

        Raises:
            NotImplementedError: If no Bedrock client is available.
            InvalidRequestError: If ``inputs`` has more than ``MAX_BATCH_SIZE``
                items.
            LLMError: The first failure from any individual request.
        """
        if self._client is None:
            raise NotImplementedError("Bedrock client not available")

        if not inputs:
            return []

        if len(inputs) > self.MAX_BATCH_SIZE:
            raise InvalidRequestError(
                f"Bedrock embedding batch size cannot exceed {self.MAX_BATCH_SIZE}"
            )

        async def _one(text: str) -> list[float]:
            async with self._semaphore:
                return await self._invoke_single(text)

        tasks = [asyncio.create_task(_one(text)) for text in inputs]
        try:
            return await asyncio.gather(*tasks)
        except Exception:
            # gather() propagates the first failure but leaves the sibling
            # tasks running. Cancel the ones still pending so a failed batch
            # stops issuing requests, then drain them before re-raising.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

218 

219 

class _BedrockChat(ChatClient):
    """Placeholder chat client: chat is not implemented for Bedrock."""

    async def chat(
        self,
        messages: list[dict[str, Any]],
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Always raise ``NotImplementedError``; the arguments are ignored."""
        raise NotImplementedError("Bedrock chat is not implemented")

225 

226 

class BedrockProvider(LLMProvider):
    """Provider that exposes AWS Bedrock Titan embedding models.

    Configuration is read from ``settings.provider_options`` (``model_id``,
    ``aws_region``, ``provisioned_throughput_arn``, ``concurrency``), with the
    model ID falling back to ``settings.models["embeddings"]``.
    """

    # Model IDs accepted by the constructor.
    SUPPORTED_MODELS = frozenset(
        {
            "amazon.titan-embed-text-v2:0",
            "amazon.titan-embed-text-v1",
            # "cohere.embed-english-v3",
        }
    )

    def __init__(
        self,
        settings: LLMSettings,
        *,
        client: Any = None,
    ):
        """Validate the configuration and prepare the Bedrock runtime client.

        Args:
            settings: LLM settings carrying provider options, model names and
                the embeddings vector size.
            client: Pre-built client to use instead of creating one with
                boto3.

        Raises:
            InvalidRequestError: If ``concurrency`` is not a positive integer,
                if no model ID is configured, or if the model ID is not in
                ``SUPPORTED_MODELS``.
        """
        self._settings = settings
        options = settings.provider_options or {}

        configured_model = options.get("model_id") or settings.models.get(
            "embeddings"
        )
        self._model_id = "" if configured_model is None else str(configured_model)
        self._aws_region = options.get("aws_region")
        self._provisioned_throughput_arn = options.get("provisioned_throughput_arn")

        concurrency_error = "Bedrock provider 'concurrency' must be a positive integer"
        try:
            self._concurrency = int(options.get("concurrency", 8))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(concurrency_error) from exc
        if self._concurrency < 1:
            raise InvalidRequestError(concurrency_error)

        if not self._model_id:
            raise InvalidRequestError(
                "Bedrock provider requires 'model_id' in "
                "llm.provider_options or llm.models.embeddings"
            )

        if self._model_id not in self.SUPPORTED_MODELS:
            supported = sorted(self.SUPPORTED_MODELS)
            raise InvalidRequestError(
                f"Bedrock provider only supports {supported}, got {self._model_id!r}"
            )

        # An explicit vector size from the settings wins; otherwise use the
        # per-model default.
        if settings.embeddings.vector_size is not None:
            self._vector_size = settings.embeddings.vector_size
        else:
            self._vector_size = DEFAULT_VECTOR_SIZES.get(self._model_id, 1024)

        # Prefer an injected client; otherwise build one when boto3 could be
        # imported, and fall back to None when it could not.
        if client is not None:
            self._client = client
        elif boto3 is not None:
            self._client = boto3.client(
                "bedrock-runtime",
                region_name=self._aws_region,
            )
        else:
            self._client = None

    def embeddings(self) -> EmbeddingsClient:
        """Return a new ``BedrockEmbeddings`` client built from this provider's configuration."""
        keyword_options: dict[str, Any] = {
            "provisioned_throughput_arn": self._provisioned_throughput_arn,
            "provider_label": "bedrock",
            "concurrency": self._concurrency,
        }
        return BedrockEmbeddings(
            self._client,
            self._model_id,
            self._vector_size,
            **keyword_options,
        )

    def chat(self) -> ChatClient:
        """Always raise ``NotImplementedError``: this provider has no chat support."""
        raise NotImplementedError("Bedrock provider does not support chat()")

    def tokenizer(self) -> TokenCounter:
        """Return a new ``BedrockTokenizer``."""
        return BedrockTokenizer()