Coverage for src/qdrant_loader/core/pipeline/workers/upsert_worker.py: 100% (115 statements)
1"""Upsert worker for upserting embedded chunks to Qdrant."""
3import asyncio
4from collections import Counter
5from collections.abc import AsyncIterator
6from typing import Any
8from qdrant_client.http import models
10from qdrant_loader.core.monitoring import prometheus_metrics
11from qdrant_loader.core.qdrant_manager import QdrantManager
12from qdrant_loader.utils.logging import LoggingConfig
14from .base_worker import BaseWorker
16logger = LoggingConfig.get_logger(__name__)


class PipelineResult:
    """Result of pipeline processing."""

    def __init__(self):
        self.success_count: int = 0
        self.error_count: int = 0
        self.successfully_processed_documents: set[str] = set()
        self.failed_document_ids: set[str] = set()
        self.errors: list[str] = []
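

# Illustrative sketch (not part of the original module): how a caller might
# summarize a PipelineResult after a run. The attribute names match the class
# above; the summary wording itself is an assumption.
def _example_summarize_result(result: PipelineResult) -> str:
    total = result.success_count + result.error_count
    return (
        f"{result.success_count}/{total} chunks upserted; "
        f"{len(result.successfully_processed_documents)} document(s) succeeded, "
        f"{len(result.failed_document_ids)} failed, "
        f"{len(result.errors)} error message(s)"
    )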


class UpsertWorker(BaseWorker):
    """Handles upserting embedded chunks to Qdrant."""

    def __init__(
        self,
        qdrant_manager: QdrantManager,
        batch_size: int,
        max_workers: int = 4,
        queue_size: int = 1000,
        shutdown_event: asyncio.Event | None = None,
    ):
        super().__init__(max_workers, queue_size)
        self.qdrant_manager = qdrant_manager
        self.batch_size = batch_size
        self.shutdown_event = shutdown_event or asyncio.Event()
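
        # Added commentary (not in the original source): max_workers and
        # queue_size are forwarded to BaseWorker; passing a shared
        # shutdown_event lets the pipeline stop this worker together with its
        # peers, while a private Event is created when none is supplied.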

    def _handle_duplicate_chunk_ids(
        self,
        batch: list[tuple[Any, list[float]]],
        batch_chunk_id_counts: Counter,
        duplicate_chunk_ids: set[str],
        same_batch_duplicates: set[str],
        cross_batch_duplicates: set[str],
        new_chunk_ids: set[str],
        successful_doc_ids: set[str],
        result: PipelineResult,
        errors: list[str],
    ) -> None:
        """Handle duplicate chunk IDs and update result/error bookkeeping."""
        if not duplicate_chunk_ids:
            return

        duplicate_doc_ids = set()
        for chunk, _ in batch:
            if str(chunk.id) in duplicate_chunk_ids:
                parent_doc = chunk.metadata.get("parent_document")
                if parent_doc:
                    duplicate_doc_ids.add(parent_doc.id)

        successful_doc_ids -= duplicate_doc_ids
        result.successfully_processed_documents -= duplicate_doc_ids

        same_batch_duplicate_occurrences = sum(
            count - 1 for count in batch_chunk_id_counts.values() if count > 1
        )
        total_duplicate_impact = len(duplicate_doc_ids)
        duplicate_chunk_attempts = len(batch) - len(new_chunk_ids)

        logger.warning(
            "Detected chunk ID collisions during upsert; existing points will be overwritten",
            duplicate_count=len(duplicate_chunk_ids),
            same_batch_duplicate_count=len(same_batch_duplicates),
            same_batch_duplicate_occurrences=same_batch_duplicate_occurrences,
            cross_batch_duplicate_count=len(cross_batch_duplicates),
            affected_documents=total_duplicate_impact,
        )
        errors.append(
            "Detected duplicate chunk IDs during upsert: "
            f"{len(cross_batch_duplicates)} cross-batch IDs and "
            f"{same_batch_duplicate_occurrences} same-batch duplicate occurrences "
            f"across {len(same_batch_duplicates)} IDs affecting {total_duplicate_impact} document(s): "
            f"{sorted(duplicate_doc_ids)}"
        )
        result.error_count += duplicate_chunk_attempts
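
    # Worked example (added commentary, not in the original source): for a
    # batch with IDs ["a", "a", "b", "c"] and seen_chunk_ids == {"c"}:
    #   same_batch_duplicates            -> {"a"}  (count > 1 within the batch)
    #   same_batch_duplicate_occurrences -> 1      (the one extra "a")
    #   cross_batch_duplicates           -> {"c"}  (already seen in a prior batch)
    #   new_chunk_ids                    -> {"b"}  (genuinely new)
    #   duplicate_chunk_attempts         -> 3      (len(batch)=4 minus 1 new ID)
    # so only "b" is counted as a success by the caller.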

    async def process(
        self, batch: list[tuple[Any, list[float]]]
    ) -> tuple[int, int, set[str], list[str]]:
        """Process a batch of embedded chunks.

        Args:
            batch: List of (chunk, embedding) tuples

        Returns:
            Tuple of (success_count, error_count, successful_doc_ids, errors)
        """
        if not batch:
            return 0, 0, set(), []

        success_count = 0
        error_count = 0
        successful_doc_ids = set()
        errors = []

        try:
            with prometheus_metrics.UPSERT_DURATION.time():
                points = [
                    models.PointStruct(
                        id=chunk.id,
                        vector=embedding,
                        payload={
                            "content": chunk.content,
                            "contextual_content": chunk.contextual_content,
                            "metadata": {
                                k: v
                                for k, v in chunk.metadata.items()
                                if k != "parent_document"
                            },
                            "source": chunk.source,
                            "source_type": chunk.source_type,
                            "created_at": chunk.created_at.isoformat(),
                            "updated_at": (
                                getattr(
                                    chunk, "updated_at", chunk.created_at
                                ).isoformat()
                                if hasattr(chunk, "updated_at")
                                else chunk.created_at.isoformat()
                            ),
                            "title": getattr(
                                chunk, "title", chunk.metadata.get("title", "")
                            ),
                            "url": getattr(chunk, "url", chunk.metadata.get("url", "")),
                            "document_id": chunk.metadata.get(
                                "parent_document_id", chunk.id
                            ),
                        },
                    )
                    for chunk, embedding in batch
                ]

                await self.qdrant_manager.upsert_points(points)
                prometheus_metrics.INGESTED_DOCUMENTS.inc(len(points))
                success_count = len(points)

                # Mark parent documents as successfully processed
                for chunk, _ in batch:
                    parent_doc = chunk.metadata.get("parent_document")
                    if parent_doc:
                        successful_doc_ids.add(parent_doc.id)

        except Exception as e:
            for chunk, _ in batch:
                logger.error(f"Upsert failed for chunk {chunk.id}: {e}")
                # Mark parent document as failed
                parent_doc = chunk.metadata.get("parent_document")
                if parent_doc:
                    successful_doc_ids.discard(parent_doc.id)  # Remove if it was added
                errors.append(f"Upsert failed for chunk {chunk.id}: {e}")
            error_count = len(batch)

        return success_count, error_count, successful_doc_ids, errors
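
    # Added commentary (not in the original source): process() only requires
    # that chunk objects expose id, content, contextual_content, metadata,
    # source, source_type and created_at; updated_at, title and url are read
    # defensively via getattr/hasattr. A minimal stand-in such as
    #   SimpleNamespace(id=uuid4(), content="...", contextual_content="",
    #                   metadata={}, source="s", source_type="t",
    #                   created_at=datetime.now(timezone.utc))
    # is therefore enough to exercise it in tests (the stand-in itself is
    # hypothetical, not part of this codebase).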

    async def process_embedded_chunks(
        self, embedded_chunks: AsyncIterator[tuple[Any, list[float]]]
    ) -> PipelineResult:
        """Upsert embedded chunks to Qdrant.

        Args:
            embedded_chunks: AsyncIterator of (chunk, embedding) tuples

        Returns:
            PipelineResult with processing statistics
        """
        logger.debug("UpsertWorker started")
        result = PipelineResult()
        batch = []
        seen_chunk_ids: set[str] = set()

        try:
            async for chunk_embedding in embedded_chunks:
                if self.shutdown_event.is_set():
                    logger.debug("UpsertWorker exiting due to shutdown")
                    break

                batch.append(chunk_embedding)

                # Process batch when it reaches the desired size
                if len(batch) >= self.batch_size:
                    batch_chunk_id_list = [str(chunk.id) for chunk, _ in batch]
                    batch_chunk_ids = set(batch_chunk_id_list)
                    batch_chunk_id_counts = Counter(batch_chunk_id_list)
                    success_count, error_count, successful_doc_ids, errors = (
                        await self.process(batch)
                    )

                    if success_count > 0:
                        same_batch_duplicates = {
                            chunk_id
                            for chunk_id, count in batch_chunk_id_counts.items()
                            if count > 1
                        }
                        cross_batch_duplicates = batch_chunk_ids & seen_chunk_ids
                        duplicate_chunk_ids = (
                            cross_batch_duplicates | same_batch_duplicates
                        )
                        new_chunk_ids = (
                            batch_chunk_ids - seen_chunk_ids - same_batch_duplicates
                        )

                        self._handle_duplicate_chunk_ids(
                            batch=batch,
                            batch_chunk_id_counts=batch_chunk_id_counts,
                            duplicate_chunk_ids=duplicate_chunk_ids,
                            same_batch_duplicates=same_batch_duplicates,
                            cross_batch_duplicates=cross_batch_duplicates,
                            new_chunk_ids=new_chunk_ids,
                            successful_doc_ids=successful_doc_ids,
                            result=result,
                            errors=errors,
                        )

                        # Only update seen_chunk_ids with non-duplicate IDs
                        seen_chunk_ids.update(new_chunk_ids)
                        result.success_count += len(new_chunk_ids)

                    result.error_count += error_count
                    result.successfully_processed_documents.update(successful_doc_ids)
                    result.errors.extend(errors)
                    batch = []

            # Process any remaining chunks in the final batch
            if batch and not self.shutdown_event.is_set():
                batch_chunk_id_list = [str(chunk.id) for chunk, _ in batch]
                batch_chunk_ids = set(batch_chunk_id_list)
                batch_chunk_id_counts = Counter(batch_chunk_id_list)
                success_count, error_count, successful_doc_ids, errors = (
                    await self.process(batch)
                )

                if success_count > 0:
                    same_batch_duplicates = {
                        chunk_id
                        for chunk_id, count in batch_chunk_id_counts.items()
                        if count > 1
                    }
                    cross_batch_duplicates = batch_chunk_ids & seen_chunk_ids
                    duplicate_chunk_ids = cross_batch_duplicates | same_batch_duplicates
                    new_chunk_ids = (
                        batch_chunk_ids - seen_chunk_ids - same_batch_duplicates
                    )

                    self._handle_duplicate_chunk_ids(
                        batch=batch,
                        batch_chunk_id_counts=batch_chunk_id_counts,
                        duplicate_chunk_ids=duplicate_chunk_ids,
                        same_batch_duplicates=same_batch_duplicates,
                        cross_batch_duplicates=cross_batch_duplicates,
                        new_chunk_ids=new_chunk_ids,
                        successful_doc_ids=successful_doc_ids,
                        result=result,
                        errors=errors,
                    )

                    # Only update seen_chunk_ids with non-duplicate IDs
                    seen_chunk_ids.update(new_chunk_ids)
                    result.success_count += len(new_chunk_ids)

                result.error_count += error_count
                result.successfully_processed_documents.update(successful_doc_ids)
                result.errors.extend(errors)

        except asyncio.CancelledError:
            logger.debug("UpsertWorker cancelled")
            raise
        finally:
            logger.debug("UpsertWorker exited")

        return result
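

# Illustrative sketch (not part of the original module): wiring the worker to
# an in-memory iterable of (chunk, embedding) pairs. The QdrantManager setup
# is elided ("manager" is assumed to be already configured), and the batch
# size of 50 is an arbitrary example value.
async def _example_run(
    manager: QdrantManager,
    chunks_with_embeddings: list[tuple[Any, list[float]]],
) -> PipelineResult:
    worker = UpsertWorker(qdrant_manager=manager, batch_size=50)

    async def _gen() -> AsyncIterator[tuple[Any, list[float]]]:
        # Adapt the plain list into the AsyncIterator the worker expects.
        for pair in chunks_with_embeddings:
            yield pair

    return await worker.process_embedded_chunks(_gen())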