Coverage for src / qdrant_loader / core / pipeline / workers / embedding_worker.py: 91%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:38 +0000

1"""Embedding worker for processing chunks into embeddings.""" 

2 

3import asyncio 

4import gc 

5from collections.abc import AsyncIterator 

6from typing import Any 

7 

8import psutil 

9 

10from qdrant_loader.core.embedding.embedding_service import EmbeddingService 

11from qdrant_loader.core.monitoring import prometheus_metrics 

12from qdrant_loader.utils.logging import LoggingConfig 

13 

14from .base_worker import BaseWorker 

15 

# Module-level logger obtained from the project's LoggingConfig, named after this module.
logger = LoggingConfig.get_logger(__name__)

17 

18 

class EmbeddingWorker(BaseWorker):
    """Handles chunk embedding with batching.

    ``process`` embeds a single batch of chunks through the injected
    ``EmbeddingService``. ``process_chunks`` consumes an async stream of
    chunks, groups them into batches of ``embedding_service.batch_size`` and
    yields ``(chunk, embedding)`` pairs as each batch completes.
    """

    # System memory usage (percent, as reported by psutil) above which a full
    # garbage collection is forced before embedding a batch.
    _HIGH_MEMORY_PERCENT = 85
    # Batches with more chunks than this trigger gc.collect() after embedding.
    _GC_BATCH_THRESHOLD = 50
    # Default limit, in seconds, for a single get_embeddings() call; sized to
    # accommodate large batches.
    _DEFAULT_EMBEDDING_TIMEOUT = 300.0

    def __init__(
        self,
        embedding_service: EmbeddingService,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
        embedding_timeout: float | None = _DEFAULT_EMBEDDING_TIMEOUT,
    ):
        """Create the worker.

        Args:
            embedding_service: Service used to compute embeddings; its
                ``batch_size`` attribute controls batching in
                ``process_chunks``.
            max_workers: Forwarded to ``BaseWorker``.
            queue_size: Forwarded to ``BaseWorker``.
            shutdown_event: Event used to signal shutdown. A new
                ``asyncio.Event`` is created when omitted.
            embedding_timeout: Maximum number of seconds to wait for one
                ``get_embeddings`` call; ``None`` waits indefinitely.
        """
        super().__init__(max_workers, queue_size)
        self.embedding_service = embedding_service
        self.shutdown_event = shutdown_event or asyncio.Event()
        self.embedding_timeout = embedding_timeout

    async def process(self, chunks: list[Any]) -> list[tuple[Any, list[float]]]:
        """Process a batch of chunks into embeddings.

        Args:
            chunks: List of chunks to embed; each chunk's ``content``
                attribute is sent to the embedding service.

        Returns:
            List of (chunk, embedding) tuples for every chunk that received a
            non-empty embedding. Returns ``[]`` for an empty batch, or when
            the shutdown event is set once the embedding call has returned.

        Raises:
            asyncio.TimeoutError: If the embedding call exceeds
                ``embedding_timeout`` (this is the builtin ``TimeoutError``
                on Python 3.11+).
            Exception: Any other error from the embedding service is logged
                and re-raised.
        """
        if not chunks:
            return []

        try:
            logger.debug(f"EmbeddingWorker processing batch of {len(chunks)} items")

            # Monitor memory usage and reclaim garbage before allocating more.
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > self._HIGH_MEMORY_PERCENT:
                logger.warning(
                    f"High memory usage detected: {memory_percent}%. Running garbage collection..."
                )
                gc.collect()

            with prometheus_metrics.EMBEDDING_DURATION.time():
                # Bound the call so a stuck backend cannot hang the pipeline.
                # Shutdown is NOT observed while waiting here; it is checked
                # only after the call returns.
                embeddings = await asyncio.wait_for(
                    self.embedding_service.get_embeddings([c.content for c in chunks]),
                    timeout=self.embedding_timeout,
                )

            # Check for shutdown before returning
            if self.shutdown_event.is_set():
                logger.debug("EmbeddingWorker skipping result due to shutdown")
                return []

            # zip() below pairs positionally and stops at the shorter input,
            # so surface a count mismatch instead of dropping chunks silently.
            if len(embeddings) != len(chunks):
                logger.warning(
                    f"Embedding service returned {len(embeddings)} embeddings "
                    f"for {len(chunks)} chunks; unmatched chunks will not be upserted"
                )

            # Filter out chunks whose embedding is empty (invalid content was
            # skipped in get_embeddings and replaced with [] placeholder).
            result = [
                (chunk, emb)
                for chunk, emb in zip(chunks, embeddings, strict=False)
                if emb
            ]
            skipped = len(chunks) - len(result)
            if skipped:
                logger.warning(
                    f"Skipped {skipped} chunk(s) with empty embeddings, they will not be upserted"
                )
            logger.debug(f"EmbeddingWorker completed batch of {len(chunks)} items")

            # Cleanup after large batches
            if len(chunks) > self._GC_BATCH_THRESHOLD:
                gc.collect()

            return result

        # asyncio.wait_for raises asyncio.TimeoutError, which is only an alias
        # of the builtin TimeoutError from Python 3.11 on; catch both.
        except (TimeoutError, asyncio.TimeoutError):
            logger.error(
                f"EmbeddingWorker timed out processing batch of {len(chunks)} items"
            )
            raise
        except Exception as e:
            logger.error(f"EmbeddingWorker error processing batch: {e}")
            raise

    async def _embed_batch_safely(
        self, batch: list[Any], total_processed: int, *, final: bool
    ) -> list[tuple[Any, list[float]]]:
        """Embed one batch via ``process``, logging and absorbing failures.

        Args:
            batch: Chunks to embed.
            total_processed: Running count before this batch, used only for
                the progress log line.
            final: Whether this is the trailing, possibly short, batch; only
                changes the wording of the log messages.

        Returns:
            The (chunk, embedding) tuples from ``process``, or ``[]`` when
            ``process`` raised an ``Exception`` (cancellation still
            propagates because ``CancelledError`` is not an ``Exception``).
        """
        label = "final " if final else ""
        try:
            logger.debug(
                f"🔄 Processing {label}embedding batch of {len(batch)} chunks..."
            )
            results = await self.process(batch)
        except Exception as e:
            logger.error(f"EmbeddingWorker {label}batch processing failed: {e}")
            # Mark chunks as failed but let the caller continue processing.
            for failed_chunk in batch:
                logger.error(f"Embedding failed for chunk {failed_chunk.id}: {e}")
            return []

        logger.info(
            f"🔗 Generated embeddings: {len(results)} items in {label}batch, "
            f"{total_processed + len(results)} total processed"
        )
        return results

    async def process_chunks(
        self, chunks: AsyncIterator[Any]
    ) -> AsyncIterator[tuple[Any, list[float]]]:
        """Process chunks into embeddings.

        Chunks are accumulated until ``embedding_service.batch_size`` is
        reached, then embedded together; any leftover chunks are embedded as
        a final batch unless shutdown has been requested. A failing batch is
        logged and skipped rather than aborting the stream.

        Args:
            chunks: AsyncIterator of chunks to process

        Yields:
            (chunk, embedding) tuples
        """
        logger.debug("EmbeddingWorker started")
        logger.info("🔄 Starting embedding generation...")
        batch_size = self.embedding_service.batch_size
        batch: list[Any] = []
        total_processed = 0

        try:
            async for chunk in chunks:
                if self.shutdown_event.is_set():
                    logger.debug("EmbeddingWorker exiting due to shutdown")
                    break

                batch.append(chunk)

                # Process batch when it reaches the desired size
                if len(batch) >= batch_size:
                    results = await self._embed_batch_safely(
                        batch, total_processed, final=False
                    )
                    total_processed += len(results)
                    for result in results:
                        yield result
                    batch = []

            # Process any remaining chunks in the final batch
            if batch and not self.shutdown_event.is_set():
                results = await self._embed_batch_safely(
                    batch, total_processed, final=True
                )
                total_processed += len(results)
                for result in results:
                    yield result

            logger.info(f"✅ Embedding completed: {total_processed} chunks processed")

        except asyncio.CancelledError:
            logger.debug("EmbeddingWorker cancelled")
            raise
        finally:
            logger.debug("EmbeddingWorker exited")