Coverage for src/qdrant_loader/core/qdrant_manager.py: 71%
221 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
1import asyncio
2from typing import Any, cast
3from urllib.parse import urlparse
5from qdrant_client import QdrantClient
6from qdrant_client.http import models
7from qdrant_client.http.models import (
8 Distance,
9 VectorParams,
10)
11from qdrant_loader_core.config import (
12 CollectionVectorCapabilities,
13 SparseRuntimeConfig,
14 parse_collection_capabilities,
15)
16from qdrant_loader_core.sparse import get_sparse_encoder
18from ..config import Settings, get_global_config, get_settings
19from ..utils.logging import LoggingConfig
# Module-level logger for this module. QdrantManager instances also obtain
# their own ``self.logger`` via the same ``LoggingConfig.get_logger(__name__)``.
logger = LoggingConfig.get_logger(__name__)
24class QdrantConnectionError(Exception):
25 """Custom exception for Qdrant connection errors."""
27 def __init__(
28 self, message: str, original_error: str | None = None, url: str | None = None
29 ):
30 self.message = message
31 self.original_error = original_error
32 self.url = url
33 super().__init__(self.message)
class QdrantManager:
    """Wrapper around ``QdrantClient`` bound to a single configured collection.

    Responsibilities visible here: connection setup (with an HTTPS upgrade
    when an API key is used against a non-local host), collection creation
    with payload indexes, construction of dense or dense+sparse point vectors
    matching the live collection schema, search, and deletion.
    """

    # Hostnames exempt from the API-key HTTPS upgrade (local development).
    _LOCAL_HOSTS = frozenset({"localhost", "127.0.0.1"})

    def __init__(self, settings: Settings | None = None):
        """Initialize the qDrant manager and connect immediately.

        Args:
            settings: The application settings; ``get_settings()`` is used
                when omitted.

        Raises:
            QdrantConnectionError: If the Qdrant client cannot be created.
        """
        self.settings = settings or get_settings()
        self.client: QdrantClient | None = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.sparse_runtime = self._resolve_sparse_runtime_config()
        # Lazily filled from the live collection schema; see
        # _get_collection_vector_capabilities().
        self._collection_vector_capabilities: CollectionVectorCapabilities | None = None
        # Ensures the "sparse requested but unsupported" warning is logged once.
        self._sparse_fallback_warning_emitted = False
        self.connect()

    def _is_api_key_present(self) -> bool:
        """
        Check if a valid API key is present.
        Returns True if the API key is a non-empty string that is not 'None' or 'null'.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        return api_key.lower() not in ["none", "null"]

    def _resolve_sparse_runtime_config(self) -> SparseRuntimeConfig:
        """Build the sparse-vector runtime config from the global ``llm`` section.

        A failure while reading the global config is logged and treated as an
        empty ``llm`` section, so ``SparseRuntimeConfig`` defaults apply.
        """
        try:
            llm = getattr(get_global_config(), "llm", None) or {}
        except Exception as e:
            self.logger.warning(
                "Failed to read global LLM config for sparse runtime; using defaults",
                error=str(e),
                exc_info=True,
            )
            llm = {}
        # Only a dict-shaped ``llm`` section is forwarded; anything else
        # results in an empty global config (i.e. defaults).
        global_config = {"llm": llm} if isinstance(llm, dict) else {}
        return SparseRuntimeConfig.from_global_config(global_config)

    def _get_collection_vector_capabilities(self) -> CollectionVectorCapabilities:
        """Return (and cache) the vector layout of the live collection.

        On a failed schema lookup a default ``CollectionVectorCapabilities()``
        is returned without caching, so the next call retries.
        """
        if self._collection_vector_capabilities is not None:
            return self._collection_vector_capabilities

        client = self._ensure_client_connected()
        try:
            info = client.get_collection(collection_name=self.collection_name)
        except Exception as e:
            # Don't cache: a transient outage would otherwise pin every
            # subsequent upsert to dense-only payload shape even after Qdrant
            # becomes reachable again, mismatching a hybrid collection schema.
            self.logger.warning(
                "Failed to inspect Qdrant collection schema; assuming dense-only",
                collection=self.collection_name,
                error=str(e),
            )
            return CollectionVectorCapabilities()

        self._collection_vector_capabilities = parse_collection_capabilities(
            info, self.sparse_runtime
        )
        return self._collection_vector_capabilities

    def _dense_query_using(self) -> str | None:
        """Return the dense vector name for ``using=``, or None for unnamed vectors."""
        caps = self._get_collection_vector_capabilities()
        if caps.has_named_dense:
            return self.sparse_runtime.dense_vector_name
        return None

    def _sparse_upsert_enabled(self) -> bool:
        """Whether upserts should carry sparse vectors.

        True only when sparse is enabled in config AND the collection has both
        a named dense vector and a sparse vector. Otherwise returns False,
        warning once if sparse was requested but is unsupported.
        """
        if not self.sparse_runtime.enabled:
            return False

        caps = self._get_collection_vector_capabilities()
        if caps.has_named_dense and caps.has_sparse:
            return True

        if not self._sparse_fallback_warning_emitted:
            self.logger.warning(
                "Sparse vectors requested but collection schema does not support them; falling back to dense-only upserts",
                collection=self.collection_name,
                dense_vector_name=self.sparse_runtime.dense_vector_name,
                sparse_vector_name=self.sparse_runtime.sparse_vector_name,
            )
            self._sparse_fallback_warning_emitted = True
        return False

    def build_point_vector(self, dense_embedding: list[float], text: str) -> object:
        """Build the point vector payload for upsert.

        Three shapes are possible depending on the live collection schema:
        - dense+sparse named dict (hybrid-ready collection),
        - dense-only named dict (legacy named-vector collection),
        - raw dense list (legacy unnamed collection).
        """
        if self._sparse_upsert_enabled():
            return self._build_hybrid_payload(dense_embedding, text)
        return self._build_dense_payload(dense_embedding)

    def _build_dense_payload(self, dense_embedding: list[float]) -> object:
        """Return dense-only payload using the named-vector shape if the collection requires it."""
        if self._dense_query_using() is not None:
            return {self.sparse_runtime.dense_vector_name: dense_embedding}
        return dense_embedding

    def _build_hybrid_payload(self, dense_embedding: list[float], text: str) -> object:
        """Return dense+sparse payload, with a dense-only fallback on encode failure."""
        try:
            sparse = get_sparse_encoder(self.sparse_runtime.model).encode_document(text)
        except Exception as e:
            self.logger.warning(
                "Failed to generate sparse vectors; falling back to dense-only upsert",
                error=str(e),
            )
            return self._build_dense_payload(dense_embedding)

        if sparse.is_empty():
            return {self.sparse_runtime.dense_vector_name: dense_embedding}
        return {
            self.sparse_runtime.dense_vector_name: dense_embedding,
            self.sparse_runtime.sparse_vector_name: models.SparseVector(
                indices=sparse.indices, values=sparse.values
            ),
        }

    def connect(self) -> None:
        """Establish connection to qDrant server.

        When an API key is configured and the URL is neither HTTPS nor a local
        host, the ``http://`` scheme is upgraded to ``https://`` so the key is
        not sent in clear text.

        Raises:
            QdrantConnectionError: "Connection error" if ``QdrantClient``
                construction fails; "Unexpected error" for any other failure
                while preparing the connection.
        """
        # Bound up-front so the error handler below can always reference it.
        url: str | None = None
        try:
            # Ensure HTTPS is used when API key is present, but only for non-local URLs
            url = self.settings.qdrant_url
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Compare the exact hostname: a substring test on netloc would
                # treat hosts such as "localhost.example.com" as local and
                # send the API key over plain HTTP.
                if (
                    parsed_url.scheme != "https"
                    and parsed_url.hostname not in self._LOCAL_HOSTS
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning("Forcing HTTPS connection due to API key usage")

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to qDrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to qDrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e

        except QdrantConnectionError:
            # Already classified above; don't re-wrap it as "Unexpected error".
            raise
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e

    def _ensure_client_connected(self) -> QdrantClient:
        """Ensure the client is connected before performing operations."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    @staticmethod
    def _config_field(section: Any, name: str) -> Any:
        """Read ``name`` from a dict-shaped or attribute-shaped config section."""
        if section is None:
            return None
        if isinstance(section, dict):
            return section.get(name)
        return getattr(section, name, None)

    @staticmethod
    def _as_vector_size(value: Any) -> int | None:
        """Return ``value`` as an int if it is a real integer (bools rejected)."""
        if isinstance(value, int) and not isinstance(value, bool):
            return int(value)
        return None

    def _resolve_vector_size(self) -> int:
        """Pick the dense vector size: unified LLM config, then legacy, then 1024.

        Never raises; config read errors fall through to the next source.
        """
        vector_size: int | None = None
        # The unified ``llm`` section may be a plain dict (as assumed by
        # _resolve_sparse_runtime_config) or an object, so support both.
        try:
            llm_settings = getattr(get_global_config(), "llm", None)
            embeddings_cfg = self._config_field(llm_settings, "embeddings")
            vector_size = self._as_vector_size(
                self._config_field(embeddings_cfg, "vector_size")
            )
        except Exception:
            vector_size = None

        if vector_size is None:
            try:
                vector_size = self._as_vector_size(
                    get_global_config().embedding.vector_size
                )
            except Exception:
                vector_size = None

        if vector_size is None:
            self.logger.warning(
                "No vector_size specified in config; falling back to 1024 (deprecated default). Set global.llm.embeddings.vector_size."
            )
            vector_size = 1024
        return vector_size

    def _create_payload_indexes(self, client: QdrantClient) -> None:
        """Create the payload indexes used for filtering; failures are only logged."""
        indexes_to_create = [
            # Essential performance indexes
            ("document_id", {"type": "keyword"}),  # Kept for backward compatibility
            ("project_id", {"type": "keyword"}),  # Critical for multi-tenant filtering
            ("source_type", {"type": "keyword"}),  # Document type filtering
            ("source", {"type": "keyword"}),  # Source path filtering
            ("title", {"type": "keyword"}),  # Title-based search and filtering
            ("created_at", {"type": "keyword"}),  # Temporal filtering
            ("updated_at", {"type": "keyword"}),  # Temporal filtering
            # Secondary performance indexes
            ("is_attachment", {"type": "bool"}),  # Attachment filtering
            ("parent_document_id", {"type": "keyword"}),  # Hierarchical relationships
            ("original_file_type", {"type": "keyword"}),  # File type filtering
            ("is_converted", {"type": "bool"}),  # Conversion status filtering
        ]

        created_indexes: list[str] = []
        failed_indexes: list[tuple[str, str]] = []

        for field_name, field_schema in indexes_to_create:
            try:
                client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_schema,  # type: ignore
                )
                created_indexes.append(field_name)
                self.logger.debug(f"Created payload index for field: {field_name}")
            except Exception as e:
                failed_indexes.append((field_name, str(e)))
                self.logger.warning(
                    f"Failed to create index for {field_name}", error=str(e)
                )

        # Log index creation summary
        self.logger.info(
            f"Collection {self.collection_name} created with indexes",
            created_indexes=created_indexes,
            failed_indexes=(
                [name for name, _ in failed_indexes] if failed_indexes else None
            ),
            total_indexes_created=len(created_indexes),
        )

        if failed_indexes:
            self.logger.warning(
                "Some indexes failed to create but collection is functional",
                failed_details=failed_indexes,
            )

    def create_collection(self) -> None:
        """Create the configured collection (plus payload indexes) if it doesn't exist.

        With ``sparse_runtime.enabled`` the collection gets a named dense
        vector and a named sparse vector; otherwise a single unnamed dense
        vector. The capability cache is set to match what was created.

        Raises:
            Exception: Any error while listing or creating the collection
                (including ``QdrantConnectionError``) is logged and re-raised.
        """
        try:
            client = self._ensure_client_connected()
            # Check if collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            vector_size = self._resolve_vector_size()

            # sparse.enabled is a strict declaration. If True, the collection
            # is created with a sparse vector; failures propagate. If False,
            # dense-only. Operators on Qdrant servers that don't support sparse
            # vectors must set sparse.enabled=false explicitly.
            dense_params = VectorParams(size=vector_size, distance=Distance.COSINE)
            if self.sparse_runtime.enabled:
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config={
                        self.sparse_runtime.dense_vector_name: dense_params
                    },
                    sparse_vectors_config={
                        self.sparse_runtime.sparse_vector_name: models.SparseVectorParams()
                    },
                )
                self.logger.info(
                    "Created Qdrant collection with dense+sparse vectors",
                    collection=self.collection_name,
                    dense_vector_name=self.sparse_runtime.dense_vector_name,
                    sparse_vector_name=self.sparse_runtime.sparse_vector_name,
                    sparse_model=self.sparse_runtime.model,
                )
            else:
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=dense_params,
                )

            # We just created the schema, so we know its shape without a lookup.
            self._collection_vector_capabilities = CollectionVectorCapabilities(
                has_named_dense=self.sparse_runtime.enabled,
                has_sparse=self.sparse_runtime.enabled,
            )

            # Create payload indexes for optimal search performance
            self._create_payload_indexes(client)
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        The blocking client call runs in a worker thread via ``asyncio.to_thread``.

        Args:
            points: List of points to upsert

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise

    def _query_dense(
        self,
        client: QdrantClient,
        query_vector: list[float],
        limit: int,
        query_filter: models.Filter | None = None,
    ) -> list[models.ScoredPoint]:
        """Run a dense ``query_points`` call, targeting the named dense vector if any."""
        query_kwargs: dict[str, Any] = {
            "collection_name": self.collection_name,
            "query": query_vector,
            "limit": limit,
        }
        if query_filter is not None:
            query_kwargs["query_filter"] = query_filter
        using = self._dense_query_using()
        if using:
            query_kwargs["using"] = using
        # Use query_points API (qdrant-client 1.10+)
        query_response = client.query_points(**query_kwargs)
        return query_response.points

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            client = self._ensure_client_connected()
            return self._query_dense(client, query_vector, limit)
        except Exception as e:
            self.logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            client = self._ensure_client_connected()

            # Build project filter
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
            return self._query_dense(
                client, query_vector, limit, query_filter=project_filter
            )
        except Exception as e:
            self.logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Scrolls through every page of the collection and records
        ``payload["project_id"] -> payload["collection_name"]`` for points that
        carry both; if points disagree for the same project, the last one wins.

        Returns:
            Dictionary mapping project_id to collection_name

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            client = self._ensure_client_connected()

            project_collections: dict[str, str] = {}
            offset: Any = None
            while True:
                scroll_kwargs: dict[str, Any] = {
                    "collection_name": self.collection_name,
                    "limit": 10000,
                    "with_payload": True,
                    "with_vectors": False,
                }
                # Only follow-up pages need an explicit offset.
                if offset is not None:
                    scroll_kwargs["offset"] = offset
                # scroll() returns (points, next_page_offset); the offset is
                # None once the last page has been read.
                points, offset = client.scroll(**scroll_kwargs)

                for point in points:
                    if point.payload:
                        project_id = point.payload.get("project_id")
                        collection_name = point.payload.get("collection_name")
                        if project_id and collection_name:
                            project_collections[project_id] = collection_name

                if offset is None or not points:
                    break

            return project_collections
        except Exception as e:
            self.logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection and forget its cached vector schema.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            # The schema is gone; a recreated collection may differ, so force
            # a fresh inspection on next use.
            self._collection_vector_capabilities = None
            self.logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            self.logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        Matches points whose ``document_id`` payload is any of ``document_ids``;
        the blocking client call runs via ``asyncio.to_thread``.

        Args:
            document_ids: List of document IDs to delete

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise