Coverage for src/qdrant_loader/core/pipeline/workers/embedding_worker.py: 91%
82 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-04 05:50 +0000
1"""Embedding worker for processing chunks into embeddings."""
3import asyncio
4import gc
5import psutil
6from collections.abc import AsyncIterator
7from typing import Any
9from qdrant_loader.core.embedding.embedding_service import EmbeddingService
10from qdrant_loader.core.monitoring import prometheus_metrics
11from qdrant_loader.utils.logging import LoggingConfig
13from .base_worker import BaseWorker
15logger = LoggingConfig.get_logger(__name__)
class EmbeddingWorker(BaseWorker):
    """Handles chunk embedding with batching.

    Consumes chunks from an async stream, groups them into batches sized by
    the embedding service, and yields (chunk, embedding) tuples. A failed
    batch is logged per-chunk and skipped so one failure does not stop the
    stream. A shared shutdown event allows cooperative early exit.
    """

    def __init__(
        self,
        embedding_service: EmbeddingService,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
    ):
        """Initialize the worker.

        Args:
            embedding_service: Service that computes embeddings; its
                ``batch_size`` attribute drives batching in process_chunks.
            max_workers: Worker-pool size, forwarded to BaseWorker.
            queue_size: Internal queue capacity, forwarded to BaseWorker.
            shutdown_event: Optional external shutdown signal; a private
                event is created when not supplied.
        """
        super().__init__(max_workers, queue_size)
        self.embedding_service = embedding_service
        self.shutdown_event = shutdown_event or asyncio.Event()

    async def process(self, chunks: list[Any]) -> list[tuple[Any, list[float]]]:
        """Process a batch of chunks into embeddings.

        Args:
            chunks: List of chunks to embed (each must expose ``.content``).

        Returns:
            List of (chunk, embedding) tuples; empty when ``chunks`` is empty
            or shutdown was signalled while embedding.

        Raises:
            asyncio.TimeoutError: If the embedding call exceeds the timeout.
            Exception: Re-raises any error from the embedding service.
        """
        if not chunks:
            return []

        try:
            logger.debug(f"EmbeddingWorker processing batch of {len(chunks)} items")

            # High memory pressure: force a collection before allocating more.
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > 85:
                logger.warning(
                    f"High memory usage detected: {memory_percent}%. Running garbage collection..."
                )
                gc.collect()

            with prometheus_metrics.EMBEDDING_DURATION.time():
                # Timeout prevents hanging on a stuck embedding backend.
                embeddings = await asyncio.wait_for(
                    self.embedding_service.get_embeddings([c.content for c in chunks]),
                    timeout=300.0,  # 5 minute timeout for large batches
                )

            # Drop results computed after shutdown was requested.
            if self.shutdown_event.is_set():
                logger.debug("EmbeddingWorker skipping result due to shutdown")
                return []

            result = list(zip(chunks, embeddings, strict=False))
            logger.debug(f"EmbeddingWorker completed batch of {len(chunks)} items")

            # Cleanup after large batches to keep peak memory down.
            if len(chunks) > 50:
                gc.collect()

            return result

        # asyncio.wait_for raises asyncio.TimeoutError; on Python < 3.11 that
        # is NOT the builtin TimeoutError, so catch the asyncio name.
        except asyncio.TimeoutError:
            logger.error(
                f"EmbeddingWorker timed out processing batch of {len(chunks)} items"
            )
            raise
        except Exception as e:
            logger.error(f"EmbeddingWorker error processing batch: {e}")
            raise

    async def _embed_batch(
        self,
        batch: list[Any],
        new_total: int,
        *,
        final: bool = False,
    ) -> list[tuple[Any, list[float]]] | None:
        """Embed one batch, logging (not raising) failures.

        Args:
            batch: Chunks to embed.
            new_total: Running total to report on success (including this batch).
            final: Whether this is the trailing partial batch (log wording only).

        Returns:
            The (chunk, embedding) pairs on success, or ``None`` when the
            batch failed (each chunk's failure is logged individually).
        """
        try:
            label = "final embedding batch" if final else "embedding batch"
            logger.debug(f"🔄 Processing {label} of {len(batch)} chunks...")
            results = await self.process(batch)
            batch_label = "final batch" if final else "batch"
            logger.info(
                f"🔗 Generated embeddings: {len(batch)} items in {batch_label}, {new_total} total processed"
            )
            return results
        except Exception as e:
            prefix = "EmbeddingWorker final" if final else "EmbeddingWorker"
            logger.error(f"{prefix} batch processing failed: {e}")
            # Mark chunks as failed but continue processing the stream.
            for failed_chunk in batch:
                logger.error(f"Embedding failed for chunk {failed_chunk.id}: {e}")
            return None

    async def process_chunks(
        self, chunks: AsyncIterator[Any]
    ) -> AsyncIterator[tuple[Any, list[float]]]:
        """Process chunks into embeddings.

        Args:
            chunks: AsyncIterator of chunks to process.

        Yields:
            (chunk, embedding) tuples. Failed batches are logged and skipped.
        """
        logger.debug("EmbeddingWorker started")
        logger.info("🔄 Starting embedding generation...")
        batch_size = self.embedding_service.batch_size
        batch: list[Any] = []
        total_processed = 0

        try:
            async for chunk in chunks:
                if self.shutdown_event.is_set():
                    logger.debug("EmbeddingWorker exiting due to shutdown")
                    break

                batch.append(chunk)

                # Process batch when it reaches the desired size.
                if len(batch) >= batch_size:
                    results = await self._embed_batch(
                        batch, total_processed + len(batch)
                    )
                    if results is not None:
                        total_processed += len(batch)
                        # Yield outside the guarded helper so a consumer error
                        # is not mislabeled as an embedding failure.
                        for result in results:
                            yield result
                    batch = []

            # Flush any remaining chunks unless we are shutting down.
            if batch and not self.shutdown_event.is_set():
                results = await self._embed_batch(
                    batch, total_processed + len(batch), final=True
                )
                if results is not None:
                    total_processed += len(batch)
                    for result in results:
                        yield result

            logger.info(f"✅ Embedding completed: {total_processed} chunks processed")

        except asyncio.CancelledError:
            logger.debug("EmbeddingWorker cancelled")
            raise
        finally:
            logger.debug("EmbeddingWorker exited")