Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 74%

175 statements  

coverage.py v7.8.2, created at 2025-06-04 05:50 +0000

import asyncio
import time
from collections.abc import Sequence

import requests
import tiktoken
from openai import OpenAI

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class EmbeddingService:
    """Service for generating embeddings using OpenAI's API or a local service."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing the API key and endpoint.
        """
        self.settings = settings
        self.endpoint = settings.global_config.embedding.endpoint.rstrip("/")
        self.model = settings.global_config.embedding.model
        self.tokenizer = settings.global_config.embedding.tokenizer
        self.batch_size = settings.global_config.embedding.batch_size

        # Initialize the client based on the endpoint: the official OpenAI client
        # for the hosted API, plain HTTP requests for any other (local) service
        if self.endpoint == "https://api.openai.com/v1":
            self.client = OpenAI(
                api_key=settings.global_config.embedding.api_key, base_url=self.endpoint
            )
            self.use_openai = True
        else:
            self.client = None
            self.use_openai = False

        # Initialize tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0.0
        self.min_request_interval = 0.5  # 500ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
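        # With these defaults the retry delays work out to 1s, 2s, 4s
        # (base_retry_delay * 2 ** (attempt - 1), capped at max_retry_delay);
        # see _retry_with_backoff below.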

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()
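        # Note: with min_request_interval = 0.5 this throttles the service to
        # roughly two requests per second.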

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry
            operation_name: Name of the operation for logging
            **kwargs: Additional arguments passed to the operation

        Returns:
            The result of the successful operation

        Raises:
            The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for initial attempt
            try:
                if attempt > 0:
                    # Calculate exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                asyncio.TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # For non-network errors, don't retry
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")
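    # Only network-level failures (timeouts, connection errors, HTTP errors)
    # are retried; any other exception propagates immediately. A sketch of the
    # calling pattern (see _process_batch below):
    #
    #     result = await self._retry_with_backoff(
    #         self._execute_embedding_request,
    #         f"embedding batch {batch_num}",
    #         batch=batch,
    #         batch_num=batch_num,
    #     )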

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts.

        Empty or invalid entries are skipped, so the result may contain fewer
        vectors than there were inputs.
        """
        if not texts:
            return []

        # Extract content if texts are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on token limits
        # Use configurable token limits from settings
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )
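        # MAX_TOKENS_PER_CHUNK bounds one piece of content; MAX_TOKENS_PER_REQUEST
        # bounds the combined token count of a single API call.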

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate content to fit within the token limit
                if self.encoding is not None:
                    # Use the tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fall back to character-based truncation (rough estimate,
                    # assuming ~4 characters per token on average)
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. "
                "You might want to adjust chunk size and/or max tokens settings in config.yaml"
            )

        # Create smart batches that respect token limits
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0
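        # Note: the packing loop below appends contents in their original order,
        # so the returned embeddings line up one-to-one with validated_contents.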

        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check if adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed

        Returns:
            List of embedding vectors
        """
        if not batch:
            return []

        # Lazily initialized per-instance counter, used only to correlate log lines
        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        logger.debug(
            "Processing embedding batch",
            batch_num=batch_num,
            batch_size=len(batch),
            total_tokens=sum(self.count_tokens(content) for content in batch),
        )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by retry logic).

        Args:
            batch: List of content strings to embed
            batch_num: Batch number for logging

        Returns:
            List of embedding vectors
        """
        try:
            if self.use_openai and self.client is not None:
                logger.debug(
                    "Getting batch embeddings from OpenAI",
                    model=self.model,
                    batch_num=batch_num,
                )

                # Use a shorter timeout for initial attempts; let the retry logic handle failures
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create, model=self.model, input=batch
                    ),
                    timeout=45.0,  # Reduced from 90s for faster failure detection
                )
                batch_embeddings = [embedding.embedding for embedding in response.data]

            else:
                # Local service request (expects an OpenAI-compatible /embeddings endpoint)
                logger.debug(
                    "Getting batch embeddings from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                    batch_num=batch_num,
                )

                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": batch, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=30,  # Reduced timeout for faster failure detection
                    ),
                    timeout=45.0,  # Reduced from 90s
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                batch_embeddings = [item["embedding"] for item in data["data"]]

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for a single text."""
        # Validate input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by retry logic).

        Args:
            text: The text to embed

        Returns:
            The embedding vector
        """
        try:
            await self._apply_rate_limit()
            if self.use_openai and self.client is not None:
                logger.debug("Getting embedding from OpenAI", model=self.model)
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create,
                        model=self.model,
                        input=[text],  # The OpenAI API expects a list
                    ),
                    timeout=30.0,  # Reduced timeout for faster failure detection
                )
                return response.data[0].embedding
            else:
                # Local service request
                logger.debug(
                    "Getting embedding from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                )
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": text, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=15,  # Reduced timeout
                    ),
                    timeout=30.0,  # Reduced timeout
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                return data["data"][0]["embedding"]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fall back to character count if no tokenizer is available
            # (overestimates real token counts, so batching stays conservative)
            return len(text)
        return len(self.encoding.encode(text))

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens in a list of text strings."""
        return [self.count_tokens(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        # 1536 matches the default dimension of OpenAI's text-embedding-ada-002
        # and text-embedding-3-small models
        return self.settings.global_config.embedding.vector_size or 1536
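# Example usage (sketch only; assumes a populated Settings object as loaded by
# qdrant_loader.config):
#
#     service = EmbeddingService(settings)
#     vectors = await service.get_embeddings(["hello world", "qdrant loader"])
#     single = await service.get_embedding("hello world")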