Coverage for src / qdrant_loader / core / worker / pool.py: 87%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
from __future__ import annotations

import asyncio
import json
import time
from json import JSONDecodeError
from typing import Any

from qdrant_loader.core.worker.handlers import JobHandler
from qdrant_loader.core.worker.queue import JobQueue
from qdrant_loader.utils.logging import LoggingConfig

# Module-level logger; the worker pool emits its "job.*" events through it.
logger = LoggingConfig.get_logger(__name__)


class QueueWorkerPool:
    """Run queue jobs with bounded concurrency and per-source serialization.

    ``worker_count`` coroutines claim jobs from the queue concurrently. Jobs
    whose payloads carry the same ``source_lock`` value are executed one at a
    time; jobs with different values may run in parallel. Every call into the
    queue object is serialized through one internal lock.
    """

    # Not referenced inside this class: ``_extract_source_key`` never falls
    # back to it. Kept so external references to the attribute keep working.
    DEFAULT_SOURCE_KEY = "__global__"

    def __init__(
        self,
        queue: JobQueue,
        handler: JobHandler,
        worker_count: int = 4,
        lease_seconds: int = 60,
        max_attempts: int = 1,
        retry_backoff_base_seconds: int = 0,
    ) -> None:
        """Validate the configuration and set up the pool's locks.

        Args:
            queue: Queue the jobs are claimed from and reported back to.
            handler: Awaitable callable invoked as ``handler(job.type, payload)``.
            worker_count: Number of concurrent worker coroutines.
            lease_seconds: Lease length passed to ``claim_next`` and to every
                ``extend_visibility`` renewal.
            max_attempts: A job whose handler fails is released for retry while
                ``job.attempts < max_attempts``; otherwise it is marked failed.
            retry_backoff_base_seconds: Base retry delay, doubled for each
                attempt already made. ``0`` requests an immediate retry.

        Raises:
            ValueError: if any numeric argument is below its minimum.
        """
        if worker_count < 1:
            raise ValueError("worker_count must be >= 1")
        if lease_seconds < 1:
            raise ValueError("lease_seconds must be >= 1")
        if max_attempts < 1:
            raise ValueError("max_attempts must be >= 1")
        if retry_backoff_base_seconds < 0:
            raise ValueError("retry_backoff_base_seconds must be >= 0")

        self._queue = queue
        self._handler = handler
        self._worker_count = worker_count
        self._lease_seconds = lease_seconds
        self._max_attempts = max_attempts
        self._retry_backoff_base_seconds = retry_backoff_base_seconds
        # One lock per source key, created lazily by ``_get_source_lock``.
        self._source_locks: dict[str, asyncio.Lock] = {}
        # Protects creation of entries in ``_source_locks``.
        self._source_locks_guard = asyncio.Lock()
        # Serializes every call made on ``self._queue``.
        self._queue_io_guard = asyncio.Lock()

    async def run_until_empty(self) -> int:
        """Drain the queue once and return number of attempted jobs.

        Each worker loops until ``claim_next`` returns ``None``. Every claimed
        job is counted exactly once, whether it finished, failed, was released
        for retry, or was rejected before running because its payload was
        invalid.

        Returns:
            The number of jobs claimed by all workers during this drain.
        """
        processed_count = 0
        processed_count_guard = asyncio.Lock()

        async def worker() -> None:
            nonlocal processed_count

            while True:
                async with self._queue_io_guard:
                    job = await self._queue.claim_next(
                        lease_seconds=self._lease_seconds
                    )
                if job is None:
                    return

                # Reject jobs that cannot be dispatched: undecodable payloads
                # and payloads without a usable ``source_lock``.
                payload, error_message = self._decode_payload(job.payload_json)
                source_key = ""
                if error_message is None:
                    try:
                        source_key = self._extract_source_key(payload)
                    except KeyError as exc:
                        # str(KeyError) is the repr() of its argument and would
                        # wrap the stored message in quotes; use the raw text.
                        error_message = str(exc.args[0]) if exc.args else str(exc)

                if error_message is not None:
                    async with self._queue_io_guard:
                        await self._queue.mark_failed(
                            job.id, error_message, claim_attempt=job.attempts
                        )
                    async with processed_count_guard:
                        processed_count += 1
                    continue

                source_lock = await self._get_source_lock(source_key)

                logger.info(
                    "job.claimed",
                    job_id=job.id,
                    job_type=job.type,
                    source_key=source_key,
                    attempt=job.attempts,
                )

                # Jobs sharing a source key run strictly one after another.
                async with source_lock:
                    logger.info(
                        "job.handler_started",
                        job_id=job.id,
                        job_type=job.type,
                        source_key=source_key,
                        attempt=job.attempts,
                    )
                    t0 = time.monotonic()
                    try:
                        # Start lease renewal task to prevent visibility timeout during handler
                        renewal_interval = max(1, self._lease_seconds // 3)
                        lease_lost = asyncio.Event()

                        # Keyword defaults snapshot the per-iteration values so
                        # the task never observes a later job's variables.
                        async def _renew_lease(
                            *,
                            current_job_id: int = job.id,
                            current_claim_attempt: int = job.attempts,
                            current_lease_seconds: int = self._lease_seconds,
                            current_renewal_interval: int = renewal_interval,
                            lease_lost_event: asyncio.Event = lease_lost,
                        ) -> None:
                            """Extend the job's lease periodically until cancelled.

                            Stops and sets ``lease_lost_event`` when the queue
                            reports that the lease could not be extended.
                            Exceptions from the queue are logged and the loop
                            keeps trying on the next interval.
                            """
                            while True:
                                await asyncio.sleep(current_renewal_interval)
                                try:
                                    async with self._queue_io_guard:
                                        extended = await self._queue.extend_visibility(
                                            current_job_id,
                                            current_lease_seconds,
                                            claim_attempt=current_claim_attempt,
                                        )
                                    if not extended:
                                        lease_lost_event.set()
                                        return
                                except Exception as exc:
                                    logger.warning(
                                        "job.lease_renew_failed",
                                        job_id=current_job_id,
                                        error=str(exc),
                                        error_type=type(exc).__name__,
                                    )

                        renewal_task = asyncio.create_task(_renew_lease())
                        try:
                            await self._handler(job.type, payload)
                            # A handler that finished after the lease was lost
                            # is treated as a failure, not as success.
                            if lease_lost.is_set():
                                raise RuntimeError(
                                    f"Lost lease for job {job.id} while handler was running"
                                )
                        finally:
                            # Always stop the renewal task and wait for it to
                            # unwind before reporting the job's outcome.
                            renewal_task.cancel()
                            try:
                                await renewal_task
                            except asyncio.CancelledError:
                                pass
                            except Exception as exc:
                                logger.warning(
                                    "job.lease_renew_teardown_failed",
                                    job_id=job.id,
                                    error=str(exc),
                                    error_type=type(exc).__name__,
                                )
                    except Exception as exc:
                        duration_ms = round((time.monotonic() - t0) * 1000)
                        async with self._queue_io_guard:
                            if job.attempts < self._max_attempts:
                                # Exponential backoff: base * 2^(attempts - 1).
                                retry_after_seconds = 0
                                if self._retry_backoff_base_seconds > 0:
                                    retry_after_seconds = (
                                        self._retry_backoff_base_seconds
                                        * (2 ** (job.attempts - 1))
                                    )
                                await self._queue.release_for_retry(
                                    job.id,
                                    str(exc),
                                    claim_attempt=job.attempts,
                                    retry_after_seconds=retry_after_seconds,
                                )
                                logger.info(
                                    "job.retry_scheduled",
                                    job_id=job.id,
                                    job_type=job.type,
                                    source_key=source_key,
                                    attempt=job.attempts,
                                    max_attempts=self._max_attempts,
                                    retry_after_seconds=retry_after_seconds,
                                    duration_ms=duration_ms,
                                    error=str(exc),
                                )
                            else:
                                await self._queue.mark_failed(
                                    job.id, str(exc), claim_attempt=job.attempts
                                )
                                logger.info(
                                    "job.failed",
                                    job_id=job.id,
                                    job_type=job.type,
                                    source_key=source_key,
                                    attempt=job.attempts,
                                    max_attempts=self._max_attempts,
                                    duration_ms=duration_ms,
                                    error=str(exc),
                                )
                    else:
                        duration_ms = round((time.monotonic() - t0) * 1000)
                        async with self._queue_io_guard:
                            await self._queue.mark_done(
                                job.id, claim_attempt=job.attempts
                            )
                        logger.info(
                            "job.done",
                            job_id=job.id,
                            job_type=job.type,
                            source_key=source_key,
                            attempt=job.attempts,
                            duration_ms=duration_ms,
                        )

                async with processed_count_guard:
                    processed_count += 1

        await asyncio.gather(*(worker() for _ in range(self._worker_count)))
        return processed_count

    async def _get_source_lock(self, source_key: str) -> asyncio.Lock:
        """Return the lock for ``source_key``, creating it on first use."""
        async with self._source_locks_guard:
            lock = self._source_locks.get(source_key)
            if lock is None:
                lock = asyncio.Lock()
                self._source_locks[source_key] = lock
            return lock

    @staticmethod
    def _decode_payload(payload_json: str) -> tuple[dict[str, Any], str | None]:
        """Parse a job's JSON payload.

        Returns:
            ``(payload, None)`` when ``payload_json`` decodes to a JSON object,
            otherwise ``({}, message)`` where ``message`` explains why the
            payload was rejected.
        """
        try:
            payload = json.loads(payload_json)
        except JSONDecodeError as exc:
            return {}, f"Invalid payload_json: {exc.msg}"

        if not isinstance(payload, dict):
            return {}, "Invalid payload_json: expected JSON object"
        return payload, None

    @classmethod
    def _extract_source_key(cls, payload: dict[str, Any]) -> str:
        """Return the per-source concurrency key from the job payload.

        Producers *must* set ``source_lock`` to a non-empty string. This is
        the explicit contract: the pool will never silently fall back to a
        global key, because that would serialize the entire pool and hide a
        missing-field bug until production load.

        Returns:
            The ``source_lock`` value with surrounding whitespace stripped.

        Raises:
            KeyError: if ``source_lock`` is absent, not a string, or blank.
        """
        value = payload.get("source_lock")
        if isinstance(value, str) and value.strip():
            return value.strip()
        raise KeyError(
            "job payload is missing a non-empty 'source_lock' field — "
            "all producers must set it explicitly"
        )