Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 75%

162 statements  

coverage.py v7.10.6, created at 2025-09-08 06:05 +0000

import asyncio
import logging
import time
from collections.abc import Sequence
from importlib import import_module

import requests
import tiktoken

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class EmbeddingService:
    """Service for generating embeddings using a provider-agnostic API (via core)."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing the API key and endpoint.
        """
        self.settings = settings
        # Build LLM settings from the global config and create the provider
        llm_settings = settings.llm_settings
        factory_mod = import_module("qdrant_loader_core.llm.factory")
        create_provider = factory_mod.create_provider
        self.provider = create_provider(llm_settings)
        self.model = llm_settings.models.get(
            "embeddings", settings.global_config.embedding.model
        )
        self.tokenizer = (
            llm_settings.tokenizer or settings.global_config.embedding.tokenizer
        )
        self.batch_size = settings.global_config.embedding.batch_size

        # Initialize the tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0
        self.min_request_interval = 0.5  # 500 ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
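
    # Usage sketch (illustrative; assumes a fully loaded Settings object):
    #
    #     settings = Settings(...)  # hypothetical: built from config.yaml
    #     service = EmbeddingService(settings)
    #     vectors = await service.get_embeddings(["first text", "second text"])
    #     assert len(vectors) == 2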

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()
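
    # Illustrative timing: with min_request_interval = 0.5, two back-to-back
    # calls are spaced at least 0.5 s apart -- the second call sleeps for
    # (0.5 - elapsed) seconds before refreshing last_request_time.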

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry.
            operation_name: Name of the operation, used for logging.
            **kwargs: Additional arguments passed to the operation.

        Returns:
            The result of the successful operation.

        Raises:
            The last exception if all retries fail.
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for the initial attempt
            try:
                if attempt > 0:
                    # Calculate the exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # For non-network errors, don't retry
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")
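
    # Worked example of the schedule above, assuming the defaults set in
    # __init__ (base_retry_delay=1.0, max_retry_delay=30.0, max_retries=3):
    #
    #     retry 1: min(1.0 * 2**0, 30.0) = 1.0 s
    #     retry 2: min(1.0 * 2**1, 30.0) = 2.0 s
    #     retry 3: min(1.0 * 2**2, 30.0) = 4.0 s
    #
    # The 30 s cap only engages once base_retry_delay * 2**(n - 1) exceeds it.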

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts."""
        if not texts:
            return []

        # Extract content when items are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or otherwise invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on the configurable token limits
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate the content to fit within the token limit
                if self.encoding is not None:
                    # Use the tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fall back to character-based truncation, assuming a rough
                    # average of ~4 characters per token
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. "
                "Consider adjusting the chunk size and/or max token settings in config.yaml"
            )

        # Create smart batches that respect token limits
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0

        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check whether adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process the current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start a new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to the current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process the final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings
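
    # Worked batching example (illustrative): with MAX_TOKENS_PER_REQUEST =
    # 8000 and validated chunks of 3000 / 3000 / 3000 / 500 tokens, the loop
    # above emits two batches -- [3000, 3000] (adding the third chunk would
    # reach 9000 tokens) and [3000, 500].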

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed.

        Returns:
            List of embedding vectors.
        """
        if not batch:
            return []

        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        # Optimization: only compute the token total when debug logging is enabled
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logger.debug(
                "Processing embedding batch",
                batch_num=batch_num,
                batch_size=len(batch),
                total_tokens=sum(self.count_tokens(content) for content in batch),
            )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by the retry logic).

        Args:
            batch: List of content strings to embed.
            batch_num: Batch number, used for logging.

        Returns:
            List of embedding vectors.
        """
        try:
            # Use the core provider for embeddings
            embeddings_client = self.provider.embeddings()
            batch_embeddings = await embeddings_client.embed(batch)

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it
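
    # Testing sketch (hypothetical, not part of this module): the provider
    # contract used above can be stubbed by mirroring the two calls this
    # class makes -- provider.embeddings() and embeddings_client.embed(texts):
    #
    #     class _FakeEmbeddings:
    #         async def embed(self, texts):
    #             return [[0.0] * 3 for _ in texts]  # dummy fixed-size vectors
    #
    #     class _FakeProvider:
    #         def embeddings(self):
    #             return _FakeEmbeddings()
    #
    #     service.provider = _FakeProvider()  # inject into an EmbeddingService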

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for a single text."""
        # Validate input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )
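
    # Illustrative behavior: surrounding whitespace is stripped before the
    # request, and empty input raises instead of returning a zero vector:
    #
    #     await service.get_embedding("  hello  ")  # embeds "hello"
    #     await service.get_embedding("")           # raises ValueError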

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by the retry logic).

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        try:
            await self._apply_rate_limit()
            embeddings_client = self.provider.embeddings()
            vectors = await embeddings_client.embed([text])
            return vectors[0]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fall back to a character count when no tokenizer is available
            return len(text)
        return len(self.encoding.encode(text))

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens in a list of text strings."""
        return [self.count_tokens(text) for text in texts]
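
    # Illustrative counts, assuming the common "cl100k_base" tiktoken encoding:
    #
    #     service.count_tokens("hello world")                # -> 2
    #     service.count_tokens_batch(["hi", "hello world"])  # -> [1, 2]
    #
    # With tokenizer == "none" the fallback counts characters instead, so
    # "hello world" would report 11.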

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        # Prefer the vector size from unified settings when available
        dimension = (
            self.settings.llm_settings.embeddings.vector_size
            or self.settings.global_config.embedding.vector_size
        )
        if not dimension:
            logger.warning(
                "Embedding dimension not set in config; using 1536 (deprecated "
                "default). Set global.llm.embeddings.vector_size."
            )
            return 1536
        return int(dimension)
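
    # Configuration sketch (illustrative; the key path is inferred from the
    # warning message above -- verify against your config.yaml schema):
    #
    #     global:
    #       llm:
    #         embeddings:
    #           vector_size: 1536  # must match the embedding model's output size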