Coverage for src/qdrant_loader/core/embedding/embedding_service.py: 75%
162 statements
import asyncio
import logging
import time
from collections.abc import Sequence
from importlib import import_module

import requests
import tiktoken

from qdrant_loader.config import Settings
from qdrant_loader.core.document import Document
from qdrant_loader.utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class EmbeddingService:
    """Service for generating embeddings using a provider-agnostic API (via core)."""

    def __init__(self, settings: Settings):
        """Initialize the embedding service.

        Args:
            settings: The application settings containing API key and endpoint.
        """
        self.settings = settings
        # Build LLM settings from global config and create provider
        llm_settings = settings.llm_settings
        factory_mod = import_module("qdrant_loader_core.llm.factory")
        create_provider = factory_mod.create_provider
        self.provider = create_provider(llm_settings)
        self.model = llm_settings.models.get(
            "embeddings", settings.global_config.embedding.model
        )
        self.tokenizer = (
            llm_settings.tokenizer or settings.global_config.embedding.tokenizer
        )
        self.batch_size = settings.global_config.embedding.batch_size
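
        # Configuration precedence note: the unified LLM settings win where
        # present. The "embeddings" entry in llm_settings.models overrides
        # global_config.embedding.model, and llm_settings.tokenizer overrides
        # global_config.embedding.tokenizer; the legacy embedding config only
        # supplies fallback defaults.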

        # Initialize tokenizer based on configuration
        if self.tokenizer == "none":
            self.encoding = None
        else:
            try:
                self.encoding = tiktoken.get_encoding(self.tokenizer)
            except Exception as e:
                logger.warning(
                    "Failed to initialize tokenizer, falling back to simple character counting",
                    error=str(e),
                    tokenizer=self.tokenizer,
                )
                self.encoding = None

        self.last_request_time = 0
        self.min_request_interval = 0.5  # 500ms between requests

        # Retry configuration for network resilience
        self.max_retries = 3
        self.base_retry_delay = 1.0  # Start with 1 second
        self.max_retry_delay = 30.0  # Cap at 30 seconds
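
        # With these defaults the backoff schedule works out to:
        #   retry 1 -> 1.0s, retry 2 -> 2.0s, retry 3 -> 4.0s
        # via min(base_retry_delay * 2 ** (attempt - 1), max_retry_delay);
        # the 30s cap only kicks in if max_retries is raised to 6 or more.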

    async def _apply_rate_limit(self):
        """Apply rate limiting between API requests."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if time_since_last_request < self.min_request_interval:
            await asyncio.sleep(self.min_request_interval - time_since_last_request)
        self.last_request_time = time.time()
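
    # Note: with min_request_interval = 0.5 this throttles each
    # EmbeddingService instance to roughly two provider calls per second.
    # The check-then-sleep sequence is not guarded by a lock, so concurrent
    # callers within one event loop iteration can briefly exceed that rate.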

    async def _retry_with_backoff(self, operation, operation_name: str, **kwargs):
        """Execute an operation with exponential backoff retry logic.

        Args:
            operation: The async operation to retry
            operation_name: Name of the operation for logging
            **kwargs: Additional arguments passed to the operation

        Returns:
            The result of the successful operation

        Raises:
            The last exception if all retries fail
        """
        last_exception = None

        for attempt in range(self.max_retries + 1):  # +1 for initial attempt
            try:
                if attempt > 0:
                    # Calculate exponential backoff delay
                    delay = min(
                        self.base_retry_delay * (2 ** (attempt - 1)),
                        self.max_retry_delay,
                    )
                    logger.warning(
                        f"Retrying {operation_name} after network error",
                        attempt=attempt,
                        max_retries=self.max_retries,
                        delay_seconds=delay,
                        last_error=str(last_exception) if last_exception else None,
                    )
                    await asyncio.sleep(delay)

                # Execute the operation
                result = await operation(**kwargs)

                if attempt > 0:
                    logger.info(
                        f"Successfully recovered {operation_name} after retries",
                        successful_attempt=attempt + 1,
                        total_attempts=attempt + 1,
                    )

                return result

            except (
                TimeoutError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                ConnectionError,
                OSError,
            ) as e:
                last_exception = e

                if attempt == self.max_retries:
                    logger.error(
                        f"All retry attempts failed for {operation_name}",
                        total_attempts=attempt + 1,
                        final_error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise

                logger.warning(
                    f"Network error in {operation_name}, will retry",
                    attempt=attempt + 1,
                    max_retries=self.max_retries,
                    error=str(e),
                    error_type=type(e).__name__,
                )

            except Exception as e:
                # For non-network errors, don't retry
                logger.error(
                    f"Non-retryable error in {operation_name}",
                    error=str(e),
                    error_type=type(e).__name__,
                )
                raise

        # This should never be reached, but just in case
        if last_exception:
            raise last_exception
        raise RuntimeError(f"Unexpected error in retry logic for {operation_name}")
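
    # Usage sketch for the retry wrapper (mirrors the real call sites below;
    # the literal values here are illustrative only):
    #
    #   result = await self._retry_with_backoff(
    #       self._execute_embedding_request,
    #       "embedding batch 1",
    #       batch=["some text"],
    #       batch_num=1,
    #   )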

    async def get_embeddings(
        self, texts: Sequence[str | Document]
    ) -> list[list[float]]:
        """Get embeddings for a list of texts."""
        if not texts:
            return []

        # Extract content if texts are Document objects
        contents = [
            text.content if isinstance(text, Document) else text for text in texts
        ]

        # Filter out empty, None, or invalid content
        valid_contents = []
        valid_indices = []
        for i, content in enumerate(contents):
            if content and isinstance(content, str) and content.strip():
                valid_contents.append(content.strip())
                valid_indices.append(i)
            else:
                logger.warning(
                    f"Skipping invalid content at index {i}: {repr(content)}"
                )

        if not valid_contents:
            logger.warning(
                "No valid content found in batch, returning empty embeddings"
            )
            return []

        logger.debug(
            "Starting batch embedding process",
            total_texts=len(contents),
            valid_texts=len(valid_contents),
            filtered_out=len(contents) - len(valid_contents),
        )

        # Validate and split content based on token limits
        # Use configurable token limits from settings
        MAX_TOKENS_PER_REQUEST = (
            self.settings.global_config.embedding.max_tokens_per_request
        )
        MAX_TOKENS_PER_CHUNK = (
            self.settings.global_config.embedding.max_tokens_per_chunk
        )

        validated_contents = []
        truncated_count = 0
        for content in valid_contents:
            token_count = self.count_tokens(content)
            if token_count > MAX_TOKENS_PER_CHUNK:
                truncated_count += 1
                logger.warning(
                    "Content exceeds maximum token limit, truncating",
                    content_length=len(content),
                    token_count=token_count,
                    max_tokens=MAX_TOKENS_PER_CHUNK,
                )
                # Truncate content to fit within token limit
                if self.encoding is not None:
                    # Use tokenizer to truncate precisely
                    tokens = self.encoding.encode(content)
                    truncated_tokens = tokens[:MAX_TOKENS_PER_CHUNK]
                    truncated_content = self.encoding.decode(truncated_tokens)
                    validated_contents.append(truncated_content)
                else:
                    # Fallback to character-based truncation (rough estimate)
                    # Assume ~4 characters per token on average
                    max_chars = MAX_TOKENS_PER_CHUNK * 4
                    validated_contents.append(content[:max_chars])
            else:
                validated_contents.append(content)

        if truncated_count > 0:
            logger.info(
                f"⚠️ Truncated {truncated_count} content items due to token limits. You might want to adjust chunk size and/or max tokens settings in config.yaml"
            )

        # Create smart batches that respect token limits
        embeddings = []
        current_batch = []
        current_batch_tokens = 0
        batch_count = 0

        for content in validated_contents:
            content_tokens = self.count_tokens(content)

            # Check if adding this content would exceed the token limit
            if current_batch and (
                current_batch_tokens + content_tokens > MAX_TOKENS_PER_REQUEST
            ):
                # Process current batch
                batch_count += 1
                batch_embeddings = await self._process_batch(current_batch)
                embeddings.extend(batch_embeddings)

                # Start new batch
                current_batch = [content]
                current_batch_tokens = content_tokens
            else:
                # Add to current batch
                current_batch.append(content)
                current_batch_tokens += content_tokens

        # Process final batch if it exists
        if current_batch:
            batch_count += 1
            batch_embeddings = await self._process_batch(current_batch)
            embeddings.extend(batch_embeddings)

        logger.info(
            f"🔗 Generated embeddings: {len(embeddings)} items in {batch_count} batches"
        )
        return embeddings
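
    # Worked example of the greedy batching above (illustrative numbers, not
    # configuration defaults): with MAX_TOKENS_PER_REQUEST = 8000 and four
    # chunks of 3000 tokens each, the loop emits [chunk1, chunk2] then
    # [chunk3, chunk4]: a batch is closed as soon as the next chunk would push
    # its running total past the per-request limit.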

    async def _process_batch(self, batch: list[str]) -> list[list[float]]:
        """Process a single batch of content for embeddings.

        Args:
            batch: List of content strings to embed

        Returns:
            List of embedding vectors
        """
        if not batch:
            return []

        batch_num = getattr(self, "_batch_counter", 0) + 1
        self._batch_counter = batch_num

        # Optimized: Only calculate tokens for debug when debug logging is enabled
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logger.debug(
                "Processing embedding batch",
                batch_num=batch_num,
                batch_size=len(batch),
                total_tokens=sum(self.count_tokens(content) for content in batch),
            )

        await self._apply_rate_limit()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_embedding_request,
            f"embedding batch {batch_num}",
            batch=batch,
            batch_num=batch_num,
        )

    async def _execute_embedding_request(
        self, batch: list[str], batch_num: int
    ) -> list[list[float]]:
        """Execute the actual embedding request (used by retry logic).

        Args:
            batch: List of content strings to embed
            batch_num: Batch number for logging

        Returns:
            List of embedding vectors
        """
        try:
            # Use core provider for embeddings
            embeddings_client = self.provider.embeddings()
            batch_embeddings = await embeddings_client.embed(batch)

            logger.debug(
                "Completed batch processing",
                batch_num=batch_num,
                processed_embeddings=len(batch_embeddings),
            )

            return batch_embeddings

        except Exception as e:
            logger.debug(
                "Embedding request failed",
                batch_num=batch_num,
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for a single text."""
        # Validate input
        if not text or not isinstance(text, str) or not text.strip():
            logger.warning(f"Invalid text for embedding: {repr(text)}")
            raise ValueError(
                "Invalid text for embedding: text must be a non-empty string"
            )

        clean_text = text.strip()

        # Use retry logic for network resilience
        return await self._retry_with_backoff(
            self._execute_single_embedding_request, "single embedding", text=clean_text
        )

    async def _execute_single_embedding_request(self, text: str) -> list[float]:
        """Execute a single embedding request (used by retry logic).

        Args:
            text: The text to embed

        Returns:
            The embedding vector
        """
        try:
            await self._apply_rate_limit()
            embeddings_client = self.provider.embeddings()
            vectors = await embeddings_client.embed([text])
            return vectors[0]
        except Exception as e:
            logger.debug(
                "Single embedding request failed",
                error=str(e),
                error_type=type(e).__name__,
            )
            raise  # Let the retry logic handle it

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string."""
        if self.encoding is None:
            # Fallback to character count if no tokenizer is available
            return len(text)
        return len(self.encoding.encode(text))
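
    # Example (assuming a cl100k_base tiktoken encoding, a common choice):
    # count_tokens("hello world") returns 2; with tokenizer "none" the same
    # call returns 11 (raw character count), so the fallback deliberately
    # overestimates and keeps token budgets conservative.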

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Count the number of tokens in a list of text strings."""
        return [self.count_tokens(text) for text in texts]

    def get_embedding_dimension(self) -> int:
        """Get the dimension of the embedding vectors."""
        # Prefer vector size from unified settings when available
        dimension = (
            self.settings.llm_settings.embeddings.vector_size
            or self.settings.global_config.embedding.vector_size
        )
        if not dimension:
            logger.warning(
                "Embedding dimension not set in config; using 1536 (deprecated default). Set global.llm.embeddings.vector_size."
            )
            return 1536
        return int(dimension)
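
# Minimal usage sketch, assuming a `Settings` object has already been loaded
# by the application's config layer (`load_settings()` below is hypothetical;
# only the EmbeddingService API comes from this module):
#
#   import asyncio
#
#   async def main(settings: Settings) -> None:
#       service = EmbeddingService(settings)
#       vectors = await service.get_embeddings(["first chunk", "second chunk"])
#       assert len(vectors[0]) == service.get_embedding_dimension()
#
#   asyncio.run(main(load_settings()))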