Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 74%
175 statements
coverage.py v7.10.0, created at 2025-07-25 11:39 +0000
import asyncio
import time
from collections.abc import Sequence

import requests
import tiktoken
from openai import OpenAI

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)
class EmbeddingService:
    """Service for generating embeddings using OpenAI's API or a local service."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing API key and endpoint.
        """
        self.settings = settings
        self.endpoint = settings.global_config.embedding.endpoint.rstrip("/")
        self.model = settings.global_config.embedding.model
        self.tokenizer = settings.global_config.embedding.tokenizer
        self.batch_size = settings.global_config.embedding.batch_size

        # Initialize the client based on the endpoint
        if self.endpoint == "https://api.openai.com/v1":
            self.client = OpenAI(
                api_key=settings.global_config.embedding.api_key, base_url=self.endpoint
            )
            self.use_openai = True
        else:
            self.client = None
            self.use_openai = False

        # Initialize the tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0
        self.min_request_interval = 0.5  # 500ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
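
        # Note: with these defaults, _retry_with_backoff below sleeps
        # min(1.0 * 2 ** (attempt - 1), 30.0) seconds between attempts,
        # i.e. 1s before the first retry, then 2s, then 4s for the third
        # and final retry.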

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()
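
    # Note: min_request_interval = 0.5 throttles the service to at most two
    # API requests per second; each request carries a whole batch, so larger
    # batches raise effective throughput without raising the request rate.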

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry
            operation_name: Name of the operation for logging
            **kwargs: Additional arguments passed to the operation

        Returns:
            The result of the successful operation

        Raises:
            The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for the initial attempt
            try:
                if attempt > 0:
                    # Calculate the exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # Non-network errors are not retried
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")
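
    # Illustrative use (hypothetical coroutine, mirroring how _process_batch
    # below drives this helper):
    #
    #     async def _ping(self) -> bool: ...
    #     ok = await self._retry_with_backoff(self._ping, "health check")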

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts."""
        if not texts:
            return []

        # Extract content if texts are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on the configurable token limits
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate content to fit within the token limit
                if self.encoding is not None:
                    # Use the tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fall back to character-based truncation (rough estimate,
                    # assuming ~4 characters per token on average)
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. "
                "You might want to adjust the chunk size and/or max token settings in config.yaml"
            )

        # Create smart batches that respect the request-level token limit
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0

        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check if adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process the current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start a new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to the current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process the final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings
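
    # Worked example (illustrative limit): with max_tokens_per_request = 8000,
    # validated chunks of 5000, 2500, and 1000 tokens are sent as two requests,
    # [5000, 2500] and [1000], because adding the third chunk would push the
    # first batch to 8500 tokens.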

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed

        Returns:
            List of embedding vectors
        """
        if not batch:
            return []

        # Lazily initialize and advance a running batch counter used to
        # correlate log entries for this batch
        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        logger.debug(
            "Processing embedding batch",
            batch_num=batch_num,
            batch_size=len(batch),
            total_tokens=sum(self.count_tokens(content) for content in batch),
        )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by retry logic).

        Args:
            batch: List of content strings to embed
            batch_num: Batch number for logging

        Returns:
            List of embedding vectors
        """
        try:
            if self.use_openai and self.client is not None:
                logger.debug(
                    "Getting batch embeddings from OpenAI",
                    model=self.model,
                    batch_num=batch_num,
                )

                # Use a shorter timeout for initial attempts and let the
                # retry logic handle failures
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create, model=self.model, input=batch
                    ),
                    timeout=45.0,  # Reduced from 90s for faster failure detection
                )
                batch_embeddings = [embedding.embedding for embedding in response.data]

            else:
                # Local service request
                logger.debug(
                    "Getting batch embeddings from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                    batch_num=batch_num,
                )

                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": batch, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=30,  # Reduced timeout for faster failure detection
                    ),
                    timeout=45.0,  # Reduced from 90s
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                batch_embeddings = [item["embedding"] for item in data["data"]]

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for a single text."""
        # Validate the input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by retry logic).

        Args:
            text: The text to embed

        Returns:
            The embedding vector
        """
        try:
            await self._apply_rate_limit()
            if self.use_openai and self.client is not None:
                logger.debug("Getting embedding from OpenAI", model=self.model)
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create,
                        model=self.model,
                        input=[text],  # The OpenAI API expects a list
                    ),
                    timeout=30.0,  # Reduced timeout for faster failure detection
                )
                return response.data[0].embedding
            else:
                # Local service request
                logger.debug(
                    "Getting embedding from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                )
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": text, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=15,  # Reduced timeout
                    ),
                    timeout=30.0,  # Reduced timeout
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                return data["data"][0]["embedding"]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fall back to character count if no tokenizer is available
            return len(text)
        return len(self.encoding.encode(text))
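
    # Note: the character-count fallback overestimates the true token count
    # (English text averages roughly four characters per token), so the
    # truncation and batching limits above remain conservative even without
    # a tokenizer.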

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens in a list of text strings."""
        return [self.count_tokens(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        return self.settings.global_config.embedding.vector_size or 1536
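

# Usage sketch (illustrative only, not part of the module; assumes a loaded
# Settings object with the embedding section configured):
#
#     service = EmbeddingService(settings)
#     vectors = await service.get_embeddings(["hello", "world"])
#     assert len(vectors) == 2
#     assert len(vectors[0]) == service.get_embedding_dimension()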