Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 74%
175 statements
coverage.py v7.8.2, created at 2025-06-04 05:50 +0000
import asyncio
import time
from collections.abc import Sequence

import requests
import tiktoken
from openai import OpenAI

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class EmbeddingService:
    """Service for generating embeddings using OpenAI's API or a local service."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing API key and endpoint.
        """
        self.settings = settings
        self.endpoint = settings.global_config.embedding.endpoint.rstrip("/")
        self.model = settings.global_config.embedding.model
        self.tokenizer = settings.global_config.embedding.tokenizer
        self.batch_size = settings.global_config.embedding.batch_size

        # Initialize client based on endpoint
        if self.endpoint == "https://api.openai.com/v1":
            self.client = OpenAI(
                api_key=settings.global_config.embedding.api_key, base_url=self.endpoint
            )
            self.use_openai = True
        else:
            self.client = None
            self.use_openai = False
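            # Any non-OpenAI endpoint is treated as a local embedding service,
            # reached via plain HTTP POSTs to {endpoint}/embeddings
            # (see _execute_embedding_request below).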

        # Initialize tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0
        self.min_request_interval = 0.5  # 500ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
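        # With these defaults, the waits before retries 1-3 are 1s, 2s, and 4s;
        # the 30s cap only comes into play if max_retries is raised.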

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
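        # A 0.5s minimum interval caps throughput at roughly two embedding
        # requests per second, shared by the batch and single-text paths.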
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry
            operation_name: Name of the operation for logging
            **kwargs: Additional arguments passed to the operation

        Returns:
            The result of the successful operation

        Raises:
            The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for initial attempt
            try:
                if attempt > 0:
                    # Calculate exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                asyncio.TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
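                # Transient transport failures (timeouts, dropped connections,
                # HTTP errors) are retried; any other exception falls through
                # to the non-retryable handler below.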
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # For non-network errors, don't retry
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts."""
        if not texts:
            return []

        # Extract content if texts are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on token limits,
        # using the configurable token limits from settings
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate content to fit within the token limit
                if self.encoding is not None:
                    # Use the tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fall back to character-based truncation
                    # (rough estimate: ~4 characters per token on average)
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. "
                "You might want to adjust chunk size and/or max tokens settings in config.yaml"
            )

        # Create smart batches that respect token limits
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0
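
        # Greedy batching: keep adding contents until the next one would push
        # the running token total past MAX_TOKENS_PER_REQUEST, then flush the
        # batch and start a new one.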
        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check if adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process the current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start a new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to the current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process the final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed

        Returns:
            List of embedding vectors
        """
        if not batch:
            return []

        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        logger.debug(
            "Processing embedding batch",
            batch_num=batch_num,
            batch_size=len(batch),
            total_tokens=sum(self.count_tokens(content) for content in batch),
        )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by retry logic).

        Args:
            batch: List of content strings to embed
            batch_num: Batch number for logging

        Returns:
            List of embedding vectors
        """
        try:
            if self.use_openai and self.client is not None:
                logger.debug(
                    "Getting batch embeddings from OpenAI",
                    model=self.model,
                    batch_num=batch_num,
                )

                # Use shorter timeout for initial attempts, let retry logic handle failures
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create, model=self.model, input=batch
                    ),
                    timeout=45.0,  # Reduced from 90s for faster failure detection
                )
                batch_embeddings = [embedding.embedding for embedding in response.data]

            else:
                # Local service request
                logger.debug(
                    "Getting batch embeddings from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                    batch_num=batch_num,
                )

                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": batch, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=30,  # Reduced timeout for faster failure detection
                    ),
                    timeout=45.0,  # Reduced from 90s
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                batch_embeddings = [item["embedding"] for item in data["data"]]

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it
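
    # Note: the local-service path is bounded by two timeouts: the
    # requests-level timeout (30s above, 15s in the single-text helper below)
    # limits the HTTP call itself, while the outer asyncio.wait_for limits
    # the whole threaded operation.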

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for a single text."""
        # Validate input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by retry logic).

        Args:
            text: The text to embed

        Returns:
            The embedding vector
        """
        try:
            await self._apply_rate_limit()
            if self.use_openai and self.client is not None:
                logger.debug("Getting embedding from OpenAI", model=self.model)
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        self.client.embeddings.create,
                        model=self.model,
                        input=[text],  # OpenAI API expects a list
                    ),
                    timeout=30.0,  # Reduced timeout for faster failure detection
                )
                return response.data[0].embedding
            else:
                # Local service request
                logger.debug(
                    "Getting embedding from local service",
                    model=self.model,
                    endpoint=self.endpoint,
                )
                response = await asyncio.wait_for(
                    asyncio.to_thread(
                        requests.post,
                        f"{self.endpoint}/embeddings",
                        json={"input": text, "model": self.model},
                        headers={"Content-Type": "application/json"},
                        timeout=15,  # Reduced timeout
                    ),
                    timeout=30.0,  # Reduced timeout
                )
                response.raise_for_status()
                data = response.json()
                if "data" not in data or not data["data"]:
                    raise ValueError(
                        "Invalid response format from local embedding service"
                    )
                return data["data"][0]["embedding"]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fall back to character count if no tokenizer is available
            return len(text)
        return len(self.encoding.encode(text))

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens in a list of text strings."""
        return [self.count_tokens(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        return self.settings.global_config.embedding.vector_size or 1536
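

# Usage sketch (illustrative only; how Settings is constructed below is a
# hypothetical example, not part of this module):
#
#     from qdrant_loader.config import Settings
#
#     settings = Settings()  # in practice, loaded from config.yaml
#     service = EmbeddingService(settings)
#     vectors = await service.get_embeddings(["first text", "second text"])
#     dim = service.get_embedding_dimension()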