Coverage for src/qdrant_loader/core/pipeline/workers/embedding_worker.py: 91%

82 statements  

« prev     ^ index     » next       coverage.py v7.10.0, created at 2025-07-25 11:39 +0000

1"""Embedding worker for processing chunks into embeddings.""" 

2 

3import asyncio 

4import gc 

5from collections.abc import AsyncIterator 

6from typing import Any 

7 

8import psutil 

9 

10from qdrant_loader.core.embedding.embedding_service import EmbeddingService 

11from qdrant_loader.core.monitoring import prometheus_metrics 

12from qdrant_loader.utils.logging import LoggingConfig 

13 

14from .base_worker import BaseWorker 

15 

16logger = LoggingConfig.get_logger(__name__) 

17 

18 

class EmbeddingWorker(BaseWorker):
    """Handles chunk embedding with batching.

    Pulls chunks from an async iterator, groups them into batches sized by
    the embedding service, and yields ``(chunk, embedding)`` tuples. A shared
    shutdown event lets in-flight work be abandoned cleanly.
    """

    def __init__(
        self,
        embedding_service: EmbeddingService,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
    ):
        """Initialize the worker.

        Args:
            embedding_service: Service used to turn chunk text into vectors.
            max_workers: Concurrency limit forwarded to BaseWorker.
            queue_size: Work-queue capacity forwarded to BaseWorker.
            shutdown_event: Optional shared shutdown signal; a private event
                is created when none is supplied so ``is_set()`` checks are
                always safe.
        """
        super().__init__(max_workers, queue_size)
        self.embedding_service = embedding_service
        self.shutdown_event = shutdown_event or asyncio.Event()

    async def process(self, chunks: list[Any]) -> list[tuple[Any, list[float]]]:
        """Process a batch of chunks into embeddings.

        Args:
            chunks: List of chunks to embed. Each chunk must expose a
                ``content`` attribute holding the text to embed.

        Returns:
            List of (chunk, embedding) tuples. Empty when ``chunks`` is
            empty or a shutdown was requested while embedding ran.

        Raises:
            TimeoutError: If the embedding call exceeds the 5-minute timeout
                (``asyncio.TimeoutError`` on Python < 3.11).
            Exception: Any other error raised by the embedding service.
        """
        if not chunks:
            return []

        try:
            logger.debug(f"EmbeddingWorker processing batch of {len(chunks)} items")

            # Under memory pressure, try to reclaim space before the
            # (potentially large) embedding allocation.
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > 85:
                logger.warning(
                    f"High memory usage detected: {memory_percent}%. Running garbage collection..."
                )
                gc.collect()

            with prometheus_metrics.EMBEDDING_DURATION.time():
                # Bound the call so a hung embedding backend cannot stall the
                # pipeline indefinitely.
                embeddings = await asyncio.wait_for(
                    self.embedding_service.get_embeddings([c.content for c in chunks]),
                    timeout=300.0,  # 5 minute timeout for large batches
                )

            # Check for shutdown before returning
            if self.shutdown_event.is_set():
                logger.debug("EmbeddingWorker skipping result due to shutdown")
                return []

            # NOTE: strict=False tolerates the service returning fewer
            # embeddings than chunks; the surplus chunks are silently dropped.
            result = list(zip(chunks, embeddings, strict=False))
            logger.debug(f"EmbeddingWorker completed batch of {len(chunks)} items")

            # Cleanup after large batches
            if len(chunks) > 50:
                gc.collect()

            return result

        except (TimeoutError, asyncio.TimeoutError):
            # asyncio.wait_for raises asyncio.TimeoutError, which is only an
            # alias of the builtin TimeoutError from Python 3.11 on. Catch
            # both so the timeout is reported correctly on 3.10 as well
            # (previously it fell through to the generic handler below).
            logger.error(
                f"EmbeddingWorker timed out processing batch of {len(chunks)} items"
            )
            raise
        except Exception as e:
            logger.error(f"EmbeddingWorker error processing batch: {e}")
            raise

    async def process_chunks(
        self, chunks: AsyncIterator[Any]
    ) -> AsyncIterator[tuple[Any, list[float]]]:
        """Process chunks into embeddings.

        Batches are sized by ``embedding_service.batch_size``. A failed batch
        is logged per-chunk and skipped; processing continues with the next
        batch rather than aborting the whole stream.

        Args:
            chunks: AsyncIterator of chunks to process.

        Yields:
            (chunk, embedding) tuples.
        """
        logger.debug("EmbeddingWorker started")
        logger.info("🔄 Starting embedding generation...")
        batch_size = self.embedding_service.batch_size
        batch: list[Any] = []
        total_processed = 0

        try:
            async for chunk in chunks:
                if self.shutdown_event.is_set():
                    logger.debug("EmbeddingWorker exiting due to shutdown")
                    break

                batch.append(chunk)

                # Process batch when it reaches the desired size
                if len(batch) >= batch_size:
                    try:
                        logger.debug(
                            f"🔄 Processing embedding batch of {len(batch)} chunks..."
                        )
                        results = await self.process(batch)
                        total_processed += len(batch)
                        logger.info(
                            f"🔗 Generated embeddings: {len(batch)} items in batch, {total_processed} total processed"
                        )

                        for result in results:
                            yield result
                    except Exception as e:
                        logger.error(f"EmbeddingWorker batch processing failed: {e}")
                        # Mark chunks as failed but continue processing.
                        # Use a distinct name so the outer async-for loop
                        # variable `chunk` is not shadowed/clobbered.
                        for failed_chunk in batch:
                            logger.error(
                                f"Embedding failed for chunk {failed_chunk.id}: {e}"
                            )

                    batch = []

            # Process any remaining chunks in the final batch
            if batch and not self.shutdown_event.is_set():
                try:
                    logger.debug(
                        f"🔄 Processing final embedding batch of {len(batch)} chunks..."
                    )
                    results = await self.process(batch)
                    total_processed += len(batch)
                    logger.info(
                        f"🔗 Generated embeddings: {len(batch)} items in final batch, {total_processed} total processed"
                    )

                    for result in results:
                        yield result
                except Exception as e:
                    logger.error(f"EmbeddingWorker final batch processing failed: {e}")
                    for failed_chunk in batch:
                        logger.error(
                            f"Embedding failed for chunk {failed_chunk.id}: {e}"
                        )

            logger.info(f"✅ Embedding completed: {total_processed} chunks processed")

        except asyncio.CancelledError:
            logger.debug("EmbeddingWorker cancelled")
            raise
        finally:
            logger.debug("EmbeddingWorker exited")