Coverage for src / qdrant_loader / core / worker / pool.py: 87%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:38 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import json 

5import time 

6from json import JSONDecodeError 

7from typing import Any 

8 

9from qdrant_loader.core.worker.handlers import JobHandler 

10from qdrant_loader.core.worker.queue import JobQueue 

11from qdrant_loader.utils.logging import LoggingConfig 

12 

# Module-level logger; the worker pool below emits events such as
# "job.claimed" and "job.done" through it, passing fields as keyword arguments.
logger = LoggingConfig.get_logger(__name__)

14 

15 

class QueueWorkerPool:
    """Run queue jobs with bounded concurrency and per-source serialization.

    ``run_until_empty`` starts ``worker_count`` worker coroutines that keep
    claiming jobs from the queue until ``claim_next`` returns ``None``. Jobs
    whose payloads carry the same ``source_lock`` value run one at a time;
    jobs for different sources may run concurrently. Every call into the
    queue is serialized through ``_queue_io_guard`` (presumably because the
    queue backend is not safe for concurrent use -- confirm against
    ``JobQueue``).
    """

    # Not used by the pool itself: ``_extract_source_key`` deliberately never
    # falls back to a global key. Kept in case external code references it.
    DEFAULT_SOURCE_KEY = "__global__"

    def __init__(
        self,
        queue: JobQueue,
        handler: JobHandler,
        worker_count: int = 4,
        lease_seconds: int = 60,
        max_attempts: int = 1,
        retry_backoff_base_seconds: int = 0,
    ) -> None:
        """Configure the pool.

        Args:
            queue: Job store used to claim jobs, renew leases and record
                outcomes.
            handler: Awaitable callable invoked as
                ``handler(job.type, payload)``.
            worker_count: Number of concurrent worker coroutines (>= 1).
            lease_seconds: Lease length requested on claim and on every
                renewal (>= 1).
            max_attempts: A failed job is released for retry while its
                ``attempts`` is below this value; otherwise it is marked
                failed (>= 1).
            retry_backoff_base_seconds: Base delay for exponential retry
                backoff; 0 requests no delay (>= 0).

        Raises:
            ValueError: if any numeric argument is below its minimum.
        """
        if worker_count < 1:
            raise ValueError("worker_count must be >= 1")
        if lease_seconds < 1:
            raise ValueError("lease_seconds must be >= 1")
        if max_attempts < 1:
            raise ValueError("max_attempts must be >= 1")
        if retry_backoff_base_seconds < 0:
            raise ValueError("retry_backoff_base_seconds must be >= 0")

        self._queue = queue
        self._handler = handler
        self._worker_count = worker_count
        self._lease_seconds = lease_seconds
        self._max_attempts = max_attempts
        self._retry_backoff_base_seconds = retry_backoff_base_seconds
        # One lock per source key; grows with the number of distinct sources.
        self._source_locks: dict[str, asyncio.Lock] = {}
        self._source_locks_guard = asyncio.Lock()
        self._queue_io_guard = asyncio.Lock()

    async def run_until_empty(self) -> int:
        """Drain the queue once and return number of attempted jobs.

        Every claimed job counts once, whether it succeeded, was released for
        retry, or was marked failed. This includes jobs rejected for an
        invalid payload or a missing ``source_lock``.

        If a worker raises (for example because a queue call fails), the
        remaining workers are cancelled and awaited before the exception
        propagates. No worker keeps touching the queue after this method
        has returned or raised.
        """
        processed_count = 0

        async def worker() -> None:
            nonlocal processed_count
            while await self._process_next_job():
                # No await between read and write, so this is safe without
                # a lock in single-threaded asyncio.
                processed_count += 1

        tasks = [asyncio.create_task(worker()) for _ in range(self._worker_count)]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel siblings when one task fails; do it
            # explicitly so none are left running detached.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        return processed_count

    async def _process_next_job(self) -> bool:
        """Claim one job and drive it to a final queue status.

        Returns:
            ``False`` when the queue had no job to claim, ``True`` otherwise
            (including jobs rejected for an unusable payload).
        """
        async with self._queue_io_guard:
            job = await self._queue.claim_next(lease_seconds=self._lease_seconds)
        if job is None:
            return False

        payload, error_message = self._decode_payload(job.payload_json)
        if error_message is not None:
            await self._reject_job(job, error_message)
            return True

        try:
            source_key = self._extract_source_key(payload)
        except KeyError as exc:
            # str(KeyError(msg)) wraps msg in quotes; store the bare message.
            await self._reject_job(job, str(exc.args[0]) if exc.args else str(exc))
            return True

        source_lock = await self._get_source_lock(source_key)

        logger.info(
            "job.claimed",
            job_id=job.id,
            job_type=job.type,
            source_key=source_key,
            attempt=job.attempts,
        )

        # Renew the lease from the moment of claim, not from the moment the
        # source lock is acquired. Otherwise a job queued behind a slow job
        # for the same source lets its lease lapse while waiting, and the
        # queue could presumably hand it to another worker.
        lease_lost = asyncio.Event()
        renewal_task = asyncio.create_task(
            self._renew_lease(job.id, job.attempts, lease_lost)
        )
        try:
            async with source_lock:
                await self._run_claimed_job(
                    job, payload, source_key, lease_lost, renewal_task
                )
        finally:
            # Normally a no-op because _run_claimed_job already stopped the
            # task; this covers cancellation while waiting for source_lock.
            await self._stop_lease_renewal(renewal_task, job.id)
        return True

    async def _run_claimed_job(
        self,
        job: Any,
        payload: dict[str, Any],
        source_key: str,
        lease_lost: asyncio.Event,
        renewal_task: asyncio.Task[None],
    ) -> None:
        """Run the handler for a claimed job and record the outcome.

        Call this while holding the job's source lock. The lease renewal
        task is stopped before the final status is written to the queue.
        A lost lease (before or during the handler) is treated as a handler
        failure.
        """
        logger.info(
            "job.handler_started",
            job_id=job.id,
            job_type=job.type,
            source_key=source_key,
            attempt=job.attempts,
        )
        t0 = time.monotonic()
        try:
            try:
                if lease_lost.is_set():
                    # Lease lapsed while waiting for the source lock; don't
                    # run the handler for a job we may no longer own.
                    raise RuntimeError(
                        f"Lost lease for job {job.id} before handler started"
                    )
                await self._handler(job.type, payload)
                if lease_lost.is_set():
                    raise RuntimeError(
                        f"Lost lease for job {job.id} while handler was running"
                    )
            finally:
                await self._stop_lease_renewal(renewal_task, job.id)
        except Exception as exc:
            duration_ms = round((time.monotonic() - t0) * 1000)
            await self._record_failure(job, source_key, exc, duration_ms)
        else:
            duration_ms = round((time.monotonic() - t0) * 1000)
            await self._record_success(job, source_key, duration_ms)

    async def _renew_lease(
        self, job_id: int, claim_attempt: int, lease_lost: asyncio.Event
    ) -> None:
        """Periodically extend the job's visibility until cancelled.

        Every ``max(1, lease_seconds // 3)`` seconds this calls
        ``extend_visibility``. When that returns a falsy value,
        ``lease_lost`` is set and the loop ends. Exceptions from the queue
        are logged and the renewal is retried on the next tick.
        """
        interval = max(1, self._lease_seconds // 3)
        while True:
            await asyncio.sleep(interval)
            try:
                async with self._queue_io_guard:
                    extended = await self._queue.extend_visibility(
                        job_id,
                        self._lease_seconds,
                        claim_attempt=claim_attempt,
                    )
            except Exception as exc:
                logger.warning(
                    "job.lease_renew_failed",
                    job_id=job_id,
                    error=str(exc),
                    error_type=type(exc).__name__,
                )
                continue
            if not extended:
                lease_lost.set()
                return

    @staticmethod
    async def _stop_lease_renewal(renewal_task: asyncio.Task[None], job_id: int) -> None:
        """Cancel the renewal task and wait for it to finish (idempotent).

        ``asyncio.wait`` is used instead of ``await renewal_task``. A
        cancellation aimed at the calling worker therefore still propagates,
        rather than being swallowed along with the renewal task's own
        ``CancelledError``.
        """
        renewal_task.cancel()
        await asyncio.wait({renewal_task})
        if renewal_task.cancelled():
            return
        exc = renewal_task.exception()
        if exc is not None:
            logger.warning(
                "job.lease_renew_teardown_failed",
                job_id=job_id,
                error=str(exc),
                error_type=type(exc).__name__,
            )

    async def _record_failure(
        self, job: Any, source_key: str, exc: Exception, duration_ms: int
    ) -> None:
        """Release the job for retry, or mark it failed once attempts run out."""
        error = str(exc)
        if job.attempts < self._max_attempts:
            retry_after_seconds = 0
            if self._retry_backoff_base_seconds > 0:
                # Exponential backoff: base, 2*base, 4*base, ... for
                # attempts 1, 2, 3, ...
                retry_after_seconds = self._retry_backoff_base_seconds * (
                    2 ** (job.attempts - 1)
                )
            async with self._queue_io_guard:
                await self._queue.release_for_retry(
                    job.id,
                    error,
                    claim_attempt=job.attempts,
                    retry_after_seconds=retry_after_seconds,
                )
            logger.info(
                "job.retry_scheduled",
                job_id=job.id,
                job_type=job.type,
                source_key=source_key,
                attempt=job.attempts,
                max_attempts=self._max_attempts,
                retry_after_seconds=retry_after_seconds,
                duration_ms=duration_ms,
                error=error,
            )
            return

        async with self._queue_io_guard:
            await self._queue.mark_failed(job.id, error, claim_attempt=job.attempts)
        logger.info(
            "job.failed",
            job_id=job.id,
            job_type=job.type,
            source_key=source_key,
            attempt=job.attempts,
            max_attempts=self._max_attempts,
            duration_ms=duration_ms,
            error=error,
        )

    async def _record_success(self, job: Any, source_key: str, duration_ms: int) -> None:
        """Mark the job done in the queue and log completion."""
        async with self._queue_io_guard:
            await self._queue.mark_done(job.id, claim_attempt=job.attempts)
        logger.info(
            "job.done",
            job_id=job.id,
            job_type=job.type,
            source_key=source_key,
            attempt=job.attempts,
            duration_ms=duration_ms,
        )

    async def _reject_job(self, job: Any, error_message: str) -> None:
        """Mark a job failed without running it because its payload is unusable."""
        async with self._queue_io_guard:
            await self._queue.mark_failed(
                job.id, error_message, claim_attempt=job.attempts
            )
        logger.warning(
            "job.rejected",
            job_id=job.id,
            job_type=job.type,
            attempt=job.attempts,
            error=error_message,
        )

    async def _get_source_lock(self, source_key: str) -> asyncio.Lock:
        """Return the lock for ``source_key``, creating it on first use."""
        async with self._source_locks_guard:
            lock = self._source_locks.get(source_key)
            if lock is None:
                lock = asyncio.Lock()
                self._source_locks[source_key] = lock
            return lock

    @staticmethod
    def _decode_payload(payload_json: str) -> tuple[dict[str, Any], str | None]:
        """Parse ``payload_json`` into a dict.

        Returns:
            ``(payload, None)`` on success, or ``({}, error_message)`` when
            the text is not valid JSON or is not a JSON object.
        """
        try:
            payload = json.loads(payload_json)
        except JSONDecodeError as exc:
            return {}, f"Invalid payload_json: {exc.msg}"

        if not isinstance(payload, dict):
            return {}, "Invalid payload_json: expected JSON object"
        return payload, None

    @classmethod
    def _extract_source_key(cls, payload: dict[str, Any]) -> str:
        """Return the per-source concurrency key from the job payload.

        Producers *must* set ``source_lock`` to a non-empty string. This is
        the explicit contract: the pool will never silently fall back to a
        global key, because that would serialize the entire pool and hide a
        missing-field bug until production load.

        Raises:
            KeyError: if ``source_lock`` is absent or blank.
        """
        value = payload.get("source_lock")
        if isinstance(value, str) and value.strip():
            return value.strip()
        raise KeyError(
            "job payload is missing a non-empty 'source_lock' field — "
            "all producers must set it explicitly"
        )