Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 74%

175 statements  

coverage.py v7.10.0, created at 2025-07-25 11:39 +0000

import asyncio
import time
from collections.abc import Sequence

import requests
import tiktoken
from openai import OpenAI

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class EmbeddingService:
    """Service for generating embeddings using OpenAI's API or a local service."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing the API key and endpoint.
        """
        self.settings = settings
        self.endpoint = settings.global_config.embedding.endpoint.rstrip("/")
        self.model = settings.global_config.embedding.model
        self.tokenizer = settings.global_config.embedding.tokenizer
        self.batch_size = settings.global_config.embedding.batch_size
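        # NOTE: batch_size is stored from the config, but the token-aware
        # batching in get_embeddings() below does not consult it.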

        # Initialize the client based on the endpoint
        if self.endpoint == "https://api.openai.com/v1":
            self.client = OpenAI(
                api_key=settings.global_config.embedding.api_key, base_url=self.endpoint
            )
            self.use_openai = True
        else:
            self.client = None
            self.use_openai = False

        # Initialize the tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0
        self.min_request_interval = 0.5  # 500 ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
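        # With these defaults, retry delays are 1s, 2s, then 4s (doubling per
        # attempt, capped at 30s); see _retry_with_backoff below.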

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()
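    # Note: this throttle is not guarded by a lock, so concurrent callers may
    # observe the same last_request_time and proceed together.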

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry.
            operation_name: Name of the operation, used for logging.
            **kwargs: Additional arguments passed to the operation.

        Returns:
            The result of the successful operation.

        Raises:
            The last exception if all retries fail.
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for the initial attempt
            try:
                if attempt > 0:
                    # Calculate the exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # Non-network errors are not retried
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")
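    # Typical invocation (as used by _process_batch and get_embedding below):
    #   await self._retry_with_backoff(
    #       self._execute_embedding_request,
    #       f"embedding batch {batch_num}",
    #       batch=batch,
    #       batch_num=batch_num,
    #   )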

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts."""
        if not texts:
            return []

        # Extract content if texts are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )
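        # NOTE: valid_indices is collected but not used below; the returned
        # embeddings line up with valid_contents, so the positions of skipped
        # inputs are not preserved for the caller.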

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on token limits,
        # using the configurable limits from settings
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate content to fit within the token limit
                if self.encoding is not None:
                    # Use the tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fall back to character-based truncation (rough estimate,
                    # assuming ~4 characters per token on average)
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. "
                "You may want to adjust the chunk size and/or max token settings in config.yaml"
            )

        # Create smart batches that respect token limits
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0

        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check whether adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process the current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start a new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to the current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process the final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed.

        Returns:
            List of embedding vectors.
        """
        if not batch:
            return []

        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        logger.debug(
            "Processing embedding batch",
            batch_num=batch_num,
            batch_size=len(batch),
            total_tokens=sum(self.count_tokens(content) for content in batch),
        )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by the retry logic).

        Args:
            batch: List of content strings to embed.
            batch_num: Batch number, used for logging.

        Returns:
            List of embedding vectors.
        """
        try:
            if self.use_openai and self.client is not None:
                logger.debug(
                    "Getting batch embeddings from OpenAI",
                    model=self.model,
                    batch_num=batch_num,
                )

                # Use a shorter timeout for initial attempts and let the retry
                # logic handle failures
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create, model=self.model, input=batch
                    ),
                    timeout=45.0,  # Reduced from 90s for faster failure detection
                )
                batch_embeddings = [embedding.embedding for embedding in response.data]

            else:
                # Local service request
                logger.debug(
                    "Getting batch embeddings from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                    batch_num=batch_num,
                )

                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": batch, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=30,  # Reduced timeout for faster failure detection
                    ),
                    timeout=45.0,  # Reduced from 90s
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                batch_embeddings = [item["embedding"] for item in data["data"]]

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    async def get_embedding(self, text: str) -> list[float]:
        """Get the embedding for a single text."""
        # Validate input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by the retry logic).

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        try:
            await self._apply_rate_limit()
            if self.use_openai and self.client is not None:
                logger.debug("Getting embedding from OpenAI", model=self.model)
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create,
                        model=self.model,
                        input=[text],  # The OpenAI API expects a list
                    ),
                    timeout=30.0,  # Reduced timeout for faster failure detection
                )
                return response.data[0].embedding
            else:
                # Local service request
                logger.debug(
                    "Getting embedding from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                )
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": text, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=15,  # Reduced timeout
                    ),
                    timeout=30.0,  # Reduced timeout
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                return data["data"][0]["embedding"]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fall back to a character count if no tokenizer is available
            return len(text)
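        # (the character-count fallback overestimates the true token count,
        # which keeps the truncation and batching limits conservative)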

        return len(self.encoding.encode(text))

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens for each text in a list of strings."""
        return [self.count_tokens(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        return self.settings.global_config.embedding.vector_size or 1536
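
# Usage sketch (illustrative only, not part of the original module): assumes a
# fully configured Settings instance; `load_settings()` is a hypothetical
# stand-in for however the application actually builds its Settings.
#
#     async def main() -> None:
#         settings = load_settings()  # hypothetical helper
#         service = EmbeddingService(settings)
#         vectors = await service.get_embeddings(["hello qdrant", "loader docs"])
#         print(len(vectors), service.get_embedding_dimension())
#
#     asyncio.run(main())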