Coverage for src/qdrant_loader/core/pipeline/workers/embedding_worker.py: 91%

82 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-04 05:50 +0000

1"""Embedding worker for processing chunks into embeddings.""" 

2 

3import asyncio 

4import gc 

5import psutil 

6from collections.abc import AsyncIterator 

7from typing import Any 

8 

9from qdrant_loader.core.embedding.embedding_service import EmbeddingService 

10from qdrant_loader.core.monitoring import prometheus_metrics 

11from qdrant_loader.utils.logging import LoggingConfig 

12 

13from .base_worker import BaseWorker 

14 

# Module-level logger, configured through the project's central LoggingConfig.
logger = LoggingConfig.get_logger(__name__)

16 

17 

class EmbeddingWorker(BaseWorker):
    """Handles chunk embedding with batching.

    Consumes chunks from an async iterator, groups them into batches sized by
    the embedding service, and yields ``(chunk, embedding)`` tuples. A shared
    shutdown event enables cooperative cancellation between batches; failed
    batches are logged per chunk and skipped so the stream keeps flowing.
    """

    def __init__(
        self,
        embedding_service: EmbeddingService,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
    ):
        """Initialize the worker.

        Args:
            embedding_service: Service used to compute embeddings; also
                supplies the batch size used by :meth:`process_chunks`.
            max_workers: Maximum concurrency (forwarded to ``BaseWorker``).
            queue_size: Internal queue capacity (forwarded to ``BaseWorker``).
            shutdown_event: Optional shared shutdown signal; a private event
                is created when none is provided.
        """
        super().__init__(max_workers, queue_size)
        self.embedding_service = embedding_service
        self.shutdown_event = shutdown_event or asyncio.Event()

    async def process(self, chunks: list[Any]) -> list[tuple[Any, list[float]]]:
        """Process a batch of chunks into embeddings.

        Args:
            chunks: List of chunks to embed

        Returns:
            List of (chunk, embedding) tuples. Empty when ``chunks`` is empty
            or when shutdown was requested while embedding.

        Raises:
            TimeoutError: If the embedding call exceeds the timeout.
            Exception: Any error raised by the embedding service.
        """
        if not chunks:
            return []

        try:
            logger.debug(f"EmbeddingWorker processing batch of {len(chunks)} items")

            # Monitor memory usage; reclaim before a potentially large
            # embedding call to reduce peak RSS.
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > 85:
                logger.warning(
                    f"High memory usage detected: {memory_percent}%. Running garbage collection..."
                )
                gc.collect()

            with prometheus_metrics.EMBEDDING_DURATION.time():
                # Add timeout to prevent hanging and check for shutdown
                embeddings = await asyncio.wait_for(
                    self.embedding_service.get_embeddings([c.content for c in chunks]),
                    timeout=300.0,  # Increased to 5 minute timeout for large batches
                )

            # Check for shutdown before returning
            if self.shutdown_event.is_set():
                logger.debug("EmbeddingWorker skipping result due to shutdown")
                return []

            # strict=False silently truncates on a length mismatch; surface
            # that condition instead of dropping chunks without a trace.
            if len(embeddings) != len(chunks):
                logger.warning(
                    f"EmbeddingWorker got {len(embeddings)} embeddings for {len(chunks)} chunks"
                )
            result = list(zip(chunks, embeddings, strict=False))
            logger.debug(f"EmbeddingWorker completed batch of {len(chunks)} items")

            # Cleanup after large batches
            if len(chunks) > 50:
                gc.collect()

            return result

        except (asyncio.TimeoutError, TimeoutError):
            # asyncio.TimeoutError only became an alias of the builtin
            # TimeoutError in Python 3.11; catch both so 3.10 timeouts are not
            # misreported by the generic handler below.
            logger.error(
                f"EmbeddingWorker timed out processing batch of {len(chunks)} items"
            )
            raise
        except Exception as e:
            logger.error(f"EmbeddingWorker error processing batch: {e}")
            raise

    async def _embed_batch(
        self, batch: list[Any], *, final: bool = False
    ) -> list[tuple[Any, list[float]]] | None:
        """Run one batch through :meth:`process`, swallowing failures.

        Args:
            batch: Chunks to embed.
            final: Whether this is the trailing partial batch (only affects
                log wording).

        Returns:
            The (chunk, embedding) tuples on success, or ``None`` when the
            batch failed. Failures are logged per chunk and never re-raised so
            the caller can keep streaming.
        """
        label = "final batch" if final else "batch"
        try:
            logger.debug(
                f"🔄 Processing {'final ' if final else ''}embedding batch of {len(batch)} chunks..."
            )
            return await self.process(batch)
        except Exception as e:
            logger.error(f"EmbeddingWorker {label} processing failed: {e}")
            # Mark chunks as failed but continue processing
            for failed_chunk in batch:
                logger.error(f"Embedding failed for chunk {failed_chunk.id}: {e}")
            return None

    async def process_chunks(
        self, chunks: AsyncIterator[Any]
    ) -> AsyncIterator[tuple[Any, list[float]]]:
        """Process chunks into embeddings.

        Args:
            chunks: AsyncIterator of chunks to process

        Yields:
            (chunk, embedding) tuples
        """
        logger.debug("EmbeddingWorker started")
        logger.info("🔄 Starting embedding generation...")
        batch_size = self.embedding_service.batch_size
        batch: list[Any] = []
        total_processed = 0

        try:
            async for chunk in chunks:
                if self.shutdown_event.is_set():
                    logger.debug("EmbeddingWorker exiting due to shutdown")
                    break

                batch.append(chunk)

                # Process batch when it reaches the desired size
                if len(batch) >= batch_size:
                    results = await self._embed_batch(batch)
                    if results is not None:
                        total_processed += len(batch)
                        logger.info(
                            f"🔗 Generated embeddings: {len(batch)} items in batch, {total_processed} total processed"
                        )
                        for result in results:
                            yield result
                    batch = []

            # Process any remaining chunks in the final batch
            if batch and not self.shutdown_event.is_set():
                results = await self._embed_batch(batch, final=True)
                if results is not None:
                    total_processed += len(batch)
                    logger.info(
                        f"🔗 Generated embeddings: {len(batch)} items in final batch, {total_processed} total processed"
                    )
                    for result in results:
                        yield result

            logger.info(f"✅ Embedding completed: {total_processed} chunks processed")

        except asyncio.CancelledError:
            logger.debug("EmbeddingWorker cancelled")
            raise
        finally:
            logger.debug("EmbeddingWorker exited")