Coverage for src / qdrant_loader / core / worker / handlers.py: 100%
80 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
1"""Job handlers for queue-based ingestion workflows."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from collections.abc import Awaitable, Callable
7from datetime import datetime, timedelta
8from typing import Any, Protocol
10from qdrant_loader.core.state.transitions import get_last_ingestion
11from qdrant_loader.utils.logging import LoggingConfig
# Module-level logger obtained from the project's LoggingConfig, keyed by this module's name.
13logger = LoggingConfig.get_logger(__name__)
# Type alias: a zero-argument callable that returns an awaitable.
# NOTE(review): presumably the awaitable resolves to a database session — confirm against callers.
15AsyncSessionFactory = Callable[[], Awaitable[Any]]
18# ---------------------------------------------------------------------------
19# Exceptions (defined first — used by classes below)
20# ---------------------------------------------------------------------------
class JobHandlerError(Exception):
    """Root of the job-handler exception hierarchy.

    Catching this type covers every error class that derives from it.
    """
class TransientJobError(JobHandlerError):
    """Signal a failure that may succeed if the job is retried."""
class PermanentJobError(JobHandlerError):
    """Signal a failure that retrying the job will not resolve."""
41# ---------------------------------------------------------------------------
42# Protocol / Abstract base
43# ---------------------------------------------------------------------------
class JobHandler(Protocol):
    """Structural type for the async callables invoked by QueueWorkerPool."""

    async def __call__(self, job_type: str, payload: dict[str, Any]) -> None:
        """Run a single job.

        Args:
            job_type: Name identifying the kind of job to execute.
            payload: Job-specific data handed to the handler.
        """
class BaseJobHandler(ABC):
    """Abstract base for handlers that dispatch on a job-type string.

    Subclasses implement one coroutine per supported job type; awaiting a
    call on the instance routes the payload to the matching coroutine.
    """

    async def __call__(self, job_type: str, payload: dict[str, Any]) -> None:
        """Route *payload* to the handler method that matches *job_type*.

        Args:
            job_type: ``"BULK_INGEST"`` or ``"INCREMENTAL_PULL"``.
            payload: Job data forwarded unchanged to the handler method.

        Raises:
            PermanentJobError: If *job_type* is not a supported value.
        """
        if job_type == "BULK_INGEST":
            return await self.handle_bulk_ingest(payload)
        if job_type == "INCREMENTAL_PULL":
            return await self.handle_incremental_pull(payload)
        raise PermanentJobError(f"Unknown job type: {job_type}")

    @abstractmethod
    async def handle_bulk_ingest(self, payload: dict[str, Any]) -> None:
        """Handle a BULK_INGEST job; must be provided by subclasses."""
        ...

    @abstractmethod
    async def handle_incremental_pull(self, payload: dict[str, Any]) -> None:
        """Handle an INCREMENTAL_PULL job; must be provided by subclasses."""
        ...

    @staticmethod
    def _calculate_since_timestamp(
        last_ingestion: datetime | None,
        overlap: timedelta = timedelta(minutes=5),
    ) -> datetime | None:
        """Return the 'since' cutoff for an incremental pull.

        The cutoff is moved back from the last ingestion time by *overlap*,
        so the pull window reaches slightly before the previous run.

        Args:
            last_ingestion: Time of the previous ingestion, or ``None`` when
                there is none.
            overlap: How far to move the cutoff back. Defaults to 5 minutes,
                which is the value that was previously hard-coded.

        Returns:
            ``last_ingestion - overlap``, or ``None`` when *last_ingestion*
            is ``None``.
        """
        if last_ingestion is None:
            return None
        return last_ingestion - overlap
84# ---------------------------------------------------------------------------
85# Registry
86# ---------------------------------------------------------------------------
class HandlerRegistry:
    """Lookup table that routes each job type to its registered handler."""

    def __init__(self) -> None:
        """Create a registry with no handlers registered."""
        self._handlers: dict[str, BaseJobHandler] = {}

    def register(self, job_type: str, handler: BaseJobHandler) -> None:
        """Associate *handler* with *job_type*, replacing any earlier entry.

        Args:
            job_type: Key under which the handler is stored.
            handler: Handler awaited for jobs of that type.
        """
        self._handlers[job_type] = handler
        handler_name = handler.__class__.__name__
        logger.debug("Registered handler", job_type=job_type, handler=handler_name)

    async def handle(self, job_type: str, payload: dict[str, Any]) -> None:
        """Await the handler registered for *job_type*, passing it the job.

        Args:
            job_type: Key of the handler to invoke.
            payload: Job data forwarded to the handler.

        Raises:
            PermanentJobError: If no handler is registered for *job_type*.
        """
        # A private sentinel separates "no entry" from any stored value.
        absent = object()
        target = self._handlers.get(job_type, absent)
        if target is absent:
            raise PermanentJobError(f"Unknown job type: {job_type}")
        await target(job_type, payload)

    def list_handlers(self) -> dict[str, str]:
        """Return a mapping of each job type to its handler's class name."""
        summary: dict[str, str] = {}
        for registered_type, registered in self._handlers.items():
            summary[registered_type] = registered.__class__.__name__
        return summary
115# ---------------------------------------------------------------------------
116# Concrete handler
117# ---------------------------------------------------------------------------
class IngestionJobHandler(BaseJobHandler):
    """Handler that runs BULK_INGEST and INCREMENTAL_PULL jobs.

    Document processing is delegated to the supplied orchestrator
    (PipelineOrchestrator) through its ``process_documents`` coroutine.
    """

    def __init__(
        self,
        orchestrator: Any,
        session_factory: AsyncSessionFactory,
    ) -> None:
        """Store the collaborators used by the job methods.

        Args:
            orchestrator: Object exposing an awaitable ``process_documents``.
            session_factory: Handed to ``get_last_ingestion`` when looking up
                the previous ingestion record.
        """
        self._session_factory = session_factory
        self._orchestrator = orchestrator

    @staticmethod
    def _validated_required_fields(payload: dict[str, Any]) -> tuple[str, str, str]:
        """Check the required payload fields and return them stripped.

        Args:
            payload: Job payload expected to carry ``source_type``,
                ``source`` and ``project_id`` as non-blank strings.

        Returns:
            ``(source_type, source, project_id)`` with surrounding
            whitespace removed.

        Raises:
            PermanentJobError: If any required field is absent, is not a
                string, or is blank after stripping.
        """
        cleaned: dict[str, str] = {}
        rejected: list[str] = []
        for name in ("source_type", "source", "project_id"):
            raw = payload.get(name)
            # Non-string values are treated the same as blank ones.
            text = raw.strip() if isinstance(raw, str) else ""
            if text:
                cleaned[name] = text
            else:
                rejected.append(name)

        if rejected:
            listing = ", ".join(rejected)
            raise PermanentJobError(
                f"Invalid job payload: missing or invalid required field(s): {listing}"
            )

        return cleaned["source_type"], cleaned["source"], cleaned["project_id"]

    async def handle_bulk_ingest(self, payload: dict[str, Any]) -> None:
        """Process the whole source with ``force=True``.

        Args:
            payload: Job payload; see ``_validated_required_fields``.

        Raises:
            PermanentJobError: If the payload fails validation, or if one
                is raised while processing.
            TransientJobError: For any other exception, chained to the
                original error.
        """
        try:
            kind, origin, project = self._validated_required_fields(payload)
            await self._orchestrator.process_documents(
                source_type=kind,
                source=origin,
                project_id=project,
                force=True,
            )
        except Exception as failure:
            # Permanent errors propagate untouched; anything else is
            # reported as retryable.
            if isinstance(failure, PermanentJobError):
                raise
            raise TransientJobError(str(failure)) from failure

    async def handle_incremental_pull(self, payload: dict[str, Any]) -> None:
        """Process the source incrementally from a computed 'since' cutoff.

        The cutoff comes from ``_calculate_since_timestamp`` applied to the
        last successful ingestion time (5 minutes earlier by default), or
        is ``None`` when no previous ingestion record is found.

        Args:
            payload: Job payload; see ``_validated_required_fields``.

        Raises:
            PermanentJobError: If the payload fails validation, or if one
                is raised while processing.
            TransientJobError: For any other exception, chained to the
                original error.
        """
        try:
            kind, origin, project = self._validated_required_fields(payload)
            record = await get_last_ingestion(
                self._session_factory,
                source_type=kind,
                source=origin,
                project_id=project,
            )

            if record:
                previous = record.last_successful_ingestion
            else:
                previous = None
            since = self._calculate_since_timestamp(previous)

            if since:
                since_text = since.isoformat()
            else:
                since_text = None
            logger.info(
                "incremental_pull.since",
                source_type=kind,
                source=origin,
                since=since_text,
            )

            await self._orchestrator.process_documents(
                source_type=kind,
                source=origin,
                project_id=project,
                force=False,
                since=since,
            )
        except Exception as failure:
            # Permanent errors propagate untouched; anything else is
            # reported as retryable.
            if isinstance(failure, PermanentJobError):
                raise
            raise TransientJobError(str(failure)) from failure