Coverage for src/qdrant_loader/core/worker/handlers.py: 100%

80 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:38 +0000

1"""Job handlers for queue-based ingestion workflows.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from collections.abc import Awaitable, Callable 

7from datetime import datetime, timedelta 

8from typing import Any, Protocol 

9 

10from qdrant_loader.core.state.transitions import get_last_ingestion 

11from qdrant_loader.utils.logging import LoggingConfig 

12 

# Zero-argument callable returning an awaitable; it is handed straight to
# ``get_last_ingestion`` (presumably it yields an async DB session — verify).
AsyncSessionFactory = Callable[[], Awaitable[Any]]

# Module-scoped logger; call sites pass structured keyword fields
# (e.g. ``job_type=...``) alongside the event name.
logger = LoggingConfig.get_logger(__name__)

16 

17 

18# --------------------------------------------------------------------------- 

19# Exceptions (defined first — used by classes below) 

20# --------------------------------------------------------------------------- 

21 

22 

class JobHandlerError(Exception):
    """Root of the job-handler exception hierarchy; catch it to handle any handler failure."""


class TransientJobError(JobHandlerError):
    """Failure expected to clear up on its own, so the job may be retried."""


class PermanentJobError(JobHandlerError):
    """Failure that retrying cannot fix (e.g. an invalid payload or an unknown job type)."""

39 

40 

41# --------------------------------------------------------------------------- 

42# Protocol / Abstract base 

43# --------------------------------------------------------------------------- 

44 

45 

class JobHandler(Protocol):
    """Structural type for the async callables that QueueWorkerPool invokes per job.

    Any object with a matching async ``__call__`` satisfies this protocol;
    no inheritance is required.
    """

    async def __call__(self, job_type: str, payload: dict[str, Any]) -> None:
        """Run one job of kind ``job_type`` using the data in ``payload``."""

52 

53 

class BaseJobHandler(ABC):
    """Base class for job handlers.

    Subclasses implement one coroutine per supported job type. Calling an
    instance dispatches on ``job_type`` to the matching coroutine.
    """

    async def __call__(self, job_type: str, payload: dict[str, Any]) -> None:
        """Dispatch to the handler coroutine matching ``job_type``.

        Args:
            job_type: Either ``"BULK_INGEST"`` or ``"INCREMENTAL_PULL"``.
            payload: Job data forwarded unchanged to the handler coroutine.

        Raises:
            PermanentJobError: If ``job_type`` is not one of the supported types.
        """
        if job_type == "BULK_INGEST":
            await self.handle_bulk_ingest(payload)
        elif job_type == "INCREMENTAL_PULL":
            await self.handle_incremental_pull(payload)
        else:
            # An unknown job type will never succeed on retry.
            raise PermanentJobError(f"Unknown job type: {job_type}")

    @abstractmethod
    async def handle_bulk_ingest(self, payload: dict[str, Any]) -> None:
        """Handle a BULK_INGEST job."""
        ...

    @abstractmethod
    async def handle_incremental_pull(self, payload: dict[str, Any]) -> None:
        """Handle an INCREMENTAL_PULL job."""
        ...

    @staticmethod
    def _calculate_since_timestamp(
        last_ingestion: datetime | None,
        overlap: timedelta = timedelta(minutes=5),
    ) -> datetime | None:
        """Compute the lower time bound for an incremental pull.

        Args:
            last_ingestion: Time of the last successful ingestion, or ``None``
                if the source has never been ingested.
            overlap: How far before ``last_ingestion`` to start. It defaults to
                five minutes, re-reading a small window around the boundary
                (presumably so late-arriving updates are not missed).

        Returns:
            ``last_ingestion - overlap``, or ``None`` when ``last_ingestion``
            is ``None``.
        """
        if last_ingestion is None:
            return None
        return last_ingestion - overlap

82 

83 

84# --------------------------------------------------------------------------- 

85# Registry 

86# --------------------------------------------------------------------------- 

87 

88 

class HandlerRegistry:
    """Lookup table from job-type strings to the handler objects that run them."""

    def __init__(self) -> None:
        """Create a registry with no handlers bound."""
        self._handlers: dict[str, BaseJobHandler] = {}

    def register(self, job_type: str, handler: BaseJobHandler) -> None:
        """Bind ``handler`` to ``job_type``, replacing any previous binding."""
        self._handlers[job_type] = handler
        handler_name = handler.__class__.__name__
        logger.debug("Registered handler", job_type=job_type, handler=handler_name)

    async def handle(self, job_type: str, payload: dict[str, Any]) -> None:
        """Run ``payload`` through the handler bound to ``job_type``.

        Raises:
            PermanentJobError: If no handler is bound to ``job_type``.
        """
        registered = self._handlers
        if job_type in registered:
            # The handler is also given job_type so it can dispatch on it itself.
            await registered[job_type](job_type, payload)
            return
        raise PermanentJobError(f"Unknown job type: {job_type}")

    def list_handlers(self) -> dict[str, str]:
        """Return a new dict mapping each bound job type to its handler's class name."""
        return {
            registered_type: bound.__class__.__name__
            for registered_type, bound in self._handlers.items()
        }

113 

114 

115# --------------------------------------------------------------------------- 

116# Concrete handler 

117# --------------------------------------------------------------------------- 

118 

119 

class IngestionJobHandler(BaseJobHandler):
    """Concrete handler for BULK_INGEST and INCREMENTAL_PULL jobs.

    Delegates to PipelineOrchestrator for document processing. Failures are
    classified for the worker:

    - Malformed payloads raise :class:`PermanentJobError`.
    - Already-classified job errors propagate unchanged.
    - Any other exception is wrapped in :class:`TransientJobError` so the job
      can be retried.
    """

    def __init__(
        self,
        orchestrator: Any,
        session_factory: AsyncSessionFactory,
    ) -> None:
        """Store the collaborators used by the job coroutines.

        Args:
            orchestrator: Object exposing an awaitable ``process_documents(...)``
                (presumably a PipelineOrchestrator — confirm against callers).
            session_factory: Async session factory forwarded to
                ``get_last_ingestion``.
        """
        self._orchestrator = orchestrator
        self._session_factory = session_factory

    @staticmethod
    def _validated_required_fields(payload: dict[str, Any]) -> tuple[str, str, str]:
        """Validate required payload fields and return their stripped values.

        Returns:
            ``(source_type, source, project_id)``, each with surrounding
            whitespace removed.

        Raises:
            PermanentJobError: If ``payload`` is not a mapping, or any required
                field is missing, not a string, or blank.
        """
        from collections.abc import Mapping

        # A non-mapping payload (e.g. None) can never become valid on retry.
        # Without this check, ``payload.get`` would raise AttributeError, which
        # the callers' broad handler would misreport as a transient failure.
        if not isinstance(payload, Mapping):
            raise PermanentJobError(
                "Invalid job payload: expected a mapping, "
                f"got {type(payload).__name__}"
            )

        required_fields = ("source_type", "source", "project_id")
        cleaned: dict[str, str] = {}
        missing_or_invalid: list[str] = []
        for field in required_fields:
            value = payload.get(field)
            if isinstance(value, str) and value.strip():
                cleaned[field] = value.strip()
            else:
                missing_or_invalid.append(field)

        if missing_or_invalid:
            fields = ", ".join(missing_or_invalid)
            raise PermanentJobError(
                f"Invalid job payload: missing or invalid required field(s): {fields}"
            )
        return cleaned["source_type"], cleaned["source"], cleaned["project_id"]

    @staticmethod
    def _as_transient(exc: Exception) -> TransientJobError:
        """Wrap an unexpected failure as a retryable error with a non-empty message."""
        # Some exceptions (e.g. a bare ``TimeoutError()``) stringify to "".
        # Fall back to the class name so the failure stays diagnosable.
        return TransientJobError(str(exc) or type(exc).__name__)

    async def handle_bulk_ingest(self, payload: dict[str, Any]) -> None:
        """Run a full ingestion for the source, bypassing change detection.

        Raises:
            PermanentJobError: If the payload is invalid.
            TransientJobError: If ingestion fails for any other reason.
        """
        try:
            source_type, source, project_id = self._validated_required_fields(payload)
            await self._orchestrator.process_documents(
                source_type=source_type,
                source=source,
                project_id=project_id,
                force=True,
            )
        except (PermanentJobError, TransientJobError):
            # Already classified; propagate as-is instead of re-wrapping.
            raise
        except Exception as exc:
            raise self._as_transient(exc) from exc

    async def handle_incremental_pull(self, payload: dict[str, Any]) -> None:
        """Run an incremental ingestion since last_ingestion - 5 min.

        When no previous successful ingestion is recorded, ``since`` is
        ``None`` and is passed to the orchestrator as such.

        Raises:
            PermanentJobError: If the payload is invalid.
            TransientJobError: If the state lookup or ingestion fails for any
                other reason.
        """
        try:
            source_type, source, project_id = self._validated_required_fields(payload)
            last = await get_last_ingestion(
                self._session_factory,
                source_type=source_type,
                source=source,
                project_id=project_id,
            )
            since = self._calculate_since_timestamp(
                last.last_successful_ingestion if last else None
            )
            logger.info(
                "incremental_pull.since",
                source_type=source_type,
                source=source,
                since=since.isoformat() if since else None,
            )
            await self._orchestrator.process_documents(
                source_type=source_type,
                source=source,
                project_id=project_id,
                force=False,
                since=since,
            )
        except (PermanentJobError, TransientJobError):
            # Already classified; propagate as-is instead of re-wrapping.
            raise
        except Exception as exc:
            raise self._as_transient(exc) from exc