Coverage for src/qdrant_loader/core/pipeline/workers/embedding_worker.py: 91%
82 statements
« prev ^ index » next coverage.py v7.10.0, created at 2025-07-25 11:39 +0000
"""Embedding worker for processing chunks into embeddings."""

import asyncio
import gc
from collections.abc import AsyncIterator
from typing import Any

import psutil

from qdrant_loader.core.embedding.embedding_service import EmbeddingService
from qdrant_loader.core.monitoring import prometheus_metrics
from qdrant_loader.utils.logging import LoggingConfig

from .base_worker import BaseWorker

logger = LoggingConfig.get_logger(__name__)
class EmbeddingWorker(BaseWorker):
    """Handles chunk embedding with batching.

    Accumulates incoming chunks into batches sized by the embedding
    service, delegates embedding computation to that service, and pairs
    each chunk with its embedding vector. Cooperates with a shared
    shutdown event so the pipeline can stop promptly mid-stream.
    """

    def __init__(
        self,
        embedding_service: EmbeddingService,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
        embedding_timeout: float = 300.0,
        memory_warning_threshold: float = 85.0,
    ):
        """Initialize the worker.

        Args:
            embedding_service: Service used to compute embeddings.
            max_workers: Concurrency limit forwarded to BaseWorker.
            queue_size: Queue capacity forwarded to BaseWorker.
            shutdown_event: Shared event signalling pipeline shutdown;
                a private event is created when not provided.
            embedding_timeout: Seconds to wait for a single embedding
                batch before timing out (default keeps the previous
                hard-coded 5-minute limit).
            memory_warning_threshold: Virtual-memory usage percentage
                above which a warning is logged and garbage collection
                is forced (default keeps the previous hard-coded 85%).
        """
        super().__init__(max_workers, queue_size)
        self.embedding_service = embedding_service
        self.shutdown_event = shutdown_event or asyncio.Event()
        self.embedding_timeout = embedding_timeout
        self.memory_warning_threshold = memory_warning_threshold

    async def process(self, chunks: list[Any]) -> list[tuple[Any, list[float]]]:
        """Process a batch of chunks into embeddings.

        Args:
            chunks: List of chunks to embed.

        Returns:
            List of (chunk, embedding) tuples; empty when the input is
            empty or shutdown was requested while embedding.

        Raises:
            asyncio.TimeoutError: When the embedding service does not
                respond within ``self.embedding_timeout`` seconds.
            Exception: Any error raised by the embedding service.
        """
        if not chunks:
            return []

        try:
            logger.debug(f"EmbeddingWorker processing batch of {len(chunks)} items")

            # Relieve memory pressure before the embedding call, which can
            # allocate heavily for large batches.
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > self.memory_warning_threshold:
                logger.warning(
                    f"High memory usage detected: {memory_percent}%. Running garbage collection..."
                )
                gc.collect()

            with prometheus_metrics.EMBEDDING_DURATION.time():
                # Bound the call so a hung embedding provider cannot stall
                # the whole pipeline.
                embeddings = await asyncio.wait_for(
                    self.embedding_service.get_embeddings([c.content for c in chunks]),
                    timeout=self.embedding_timeout,
                )

            # Check for shutdown before returning
            if self.shutdown_event.is_set():
                logger.debug("EmbeddingWorker skipping result due to shutdown")
                return []

            # strict=False: tolerate a service returning fewer embeddings
            # than chunks rather than raising mid-pipeline.
            result = list(zip(chunks, embeddings, strict=False))
            logger.debug(f"EmbeddingWorker completed batch of {len(chunks)} items")

            # Cleanup after large batches
            if len(chunks) > 50:
                gc.collect()

            return result

        except (TimeoutError, asyncio.TimeoutError):
            # Bug fix: asyncio.wait_for raises asyncio.TimeoutError, which is
            # only an alias of the builtin TimeoutError on Python 3.11+.
            # Catching both keeps this handler effective on Python 3.10.
            logger.error(
                f"EmbeddingWorker timed out processing batch of {len(chunks)} items"
            )
            raise
        except Exception as e:
            logger.error(f"EmbeddingWorker error processing batch: {e}")
            raise

    async def _embed_batch(
        self,
        batch: list[Any],
        total_processed: int,
        *,
        final: bool = False,
    ) -> tuple[list[tuple[Any, list[float]]], int]:
        """Embed one accumulated batch, swallowing (but logging) failures.

        Shared by the in-stream and final-batch paths of
        :meth:`process_chunks`; never raises, so one failed batch does not
        abort the stream.

        Args:
            batch: Chunks to embed.
            total_processed: Running count of successfully embedded chunks.
            final: Whether this is the trailing (final) batch; only affects
                log wording.

        Returns:
            Tuple of (results, updated total_processed); results is empty
            and the count unchanged when the batch failed.
        """
        prefix = "final " if final else ""
        try:
            logger.debug(
                f"🔄 Processing {prefix}embedding batch of {len(batch)} chunks..."
            )
            results = await self.process(batch)
            total_processed += len(batch)
            logger.info(
                f"🔗 Generated embeddings: {len(batch)} items in {prefix}batch, {total_processed} total processed"
            )
            return results, total_processed
        except Exception as e:
            logger.error(f"EmbeddingWorker {prefix}batch processing failed: {e}")
            # Mark chunks as failed but continue processing
            for chunk in batch:
                logger.error(f"Embedding failed for chunk {chunk.id}: {e}")
            return [], total_processed

    async def process_chunks(
        self, chunks: AsyncIterator[Any]
    ) -> AsyncIterator[tuple[Any, list[float]]]:
        """Process chunks into embeddings.

        Batches are flushed when they reach the embedding service's
        ``batch_size``; a trailing partial batch is flushed at the end.
        A batch failure is logged per chunk and the stream continues.

        Args:
            chunks: AsyncIterator of chunks to process.

        Yields:
            (chunk, embedding) tuples.
        """
        logger.debug("EmbeddingWorker started")
        logger.info("🔄 Starting embedding generation...")
        batch_size = self.embedding_service.batch_size
        batch: list[Any] = []
        total_processed = 0

        try:
            async for chunk in chunks:
                if self.shutdown_event.is_set():
                    logger.debug("EmbeddingWorker exiting due to shutdown")
                    break

                batch.append(chunk)

                # Process batch when it reaches the desired size
                if len(batch) >= batch_size:
                    results, total_processed = await self._embed_batch(
                        batch, total_processed
                    )
                    for result in results:
                        yield result
                    batch = []

            # Process any remaining chunks in the final batch
            if batch and not self.shutdown_event.is_set():
                results, total_processed = await self._embed_batch(
                    batch, total_processed, final=True
                )
                for result in results:
                    yield result

            logger.info(f"✅ Embedding completed: {total_processed} chunks processed")

        except asyncio.CancelledError:
            logger.debug("EmbeddingWorker cancelled")
            raise
        finally:
            logger.debug("EmbeddingWorker exited")