Coverage for src / qdrant_loader / core / qdrant_manager.py: 80%
152 statements
« prev ^ index » next — coverage.py v7.13.0, created at 2025-12-12 09:46 +0000
1import asyncio
2from typing import cast
3from urllib.parse import urlparse
5from qdrant_client import QdrantClient
6from qdrant_client.http import models
7from qdrant_client.http.models import (
8 Distance,
9 VectorParams,
10)
12from ..config import Settings, get_global_config, get_settings
13from ..utils.logging import LoggingConfig
15logger = LoggingConfig.get_logger(__name__)
18class QdrantConnectionError(Exception):
19 """Custom exception for Qdrant connection errors."""
21 def __init__(
22 self, message: str, original_error: str | None = None, url: str | None = None
23 ):
24 self.message = message
25 self.original_error = original_error
26 self.url = url
27 super().__init__(self.message)
class QdrantManager:
    """Manage a qDrant collection: connection, creation, upserts, search, deletes."""

    def __init__(self, settings: Settings | None = None):
        """Initialize the qDrant manager and connect immediately.

        Args:
            settings: The application settings; falls back to the global
                settings when omitted.
        """
        self.settings = settings or get_settings()
        self.client = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.connect()

    def _is_api_key_present(self) -> bool:
        """Check if a valid API key is present.

        Returns True if the API key is a non-empty string that is not the
        literal placeholder "None" or "null" (case-insensitive).
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        return api_key.lower() not in ("none", "null")

    def connect(self) -> None:
        """Establish connection to the qDrant server.

        When an API key is configured, the URL scheme is upgraded to HTTPS
        for non-local hosts so the key is not sent in clear text.

        Raises:
            QdrantConnectionError: If the client cannot be created.
        """
        # Resolve the URL before entering the try block so the error handlers
        # below can always reference it (previously a failure while reading
        # settings inside the try would hit an unbound `url`).
        url = self.settings.qdrant_url
        try:
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Only force HTTPS for non-local hosts. Compare the parsed
                # hostname exactly: substring matching on netloc would also
                # accept hosts such as "notlocalhost.example.com".
                if parsed_url.scheme != "https" and parsed_url.hostname not in (
                    "localhost",
                    "127.0.0.1",
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning(
                        "Forcing HTTPS connection due to API key usage"
                    )

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to qDrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to qDrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e
        except QdrantConnectionError:
            # Already wrapped above; re-raise as-is instead of re-wrapping it
            # a second time as an "Unexpected error", which obscured the
            # original cause.
            raise
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e

    def _ensure_client_connected(self) -> QdrantClient:
        """Return the connected client, raising if connect() never succeeded."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    def _resolve_vector_size(self) -> int:
        """Determine the embedding vector size from configuration.

        Prefers the unified LLM settings (global.llm.embeddings.vector_size),
        then the legacy embedding config, then a deprecated default of 1536.
        """
        vector_size: int | None = None
        try:
            global_cfg = get_global_config()
            llm_settings = getattr(global_cfg, "llm", None)
            if llm_settings is not None:
                embeddings_cfg = getattr(llm_settings, "embeddings", None)
                vs = (
                    getattr(embeddings_cfg, "vector_size", None)
                    if embeddings_cfg is not None
                    else None
                )
                if isinstance(vs, int):
                    vector_size = int(vs)
        except Exception:
            vector_size = None

        if vector_size is None:
            try:
                legacy_vs = get_global_config().embedding.vector_size
                if isinstance(legacy_vs, int):
                    vector_size = int(legacy_vs)
            except Exception:
                vector_size = None

        if vector_size is None:
            self.logger.warning(
                "No vector_size specified in config; falling back to 1536 (deprecated default). Set global.llm.embeddings.vector_size."
            )
            vector_size = 1536
        return vector_size

    def _create_payload_indexes(self, client: QdrantClient) -> None:
        """Create payload indexes for search performance (best-effort).

        Individual index failures are logged and collected but never raised:
        the collection remains functional without them.
        """
        indexes_to_create = [
            # Essential performance indexes
            (
                "document_id",
                {"type": "keyword"},
            ),  # Existing index, kept for backward compatibility
            (
                "project_id",
                {"type": "keyword"},
            ),  # Critical for multi-tenant filtering
            ("source_type", {"type": "keyword"}),  # Document type filtering
            ("source", {"type": "keyword"}),  # Source path filtering
            ("title", {"type": "keyword"}),  # Title-based search and filtering
            ("created_at", {"type": "keyword"}),  # Temporal filtering
            ("updated_at", {"type": "keyword"}),  # Temporal filtering
            # Secondary performance indexes
            ("is_attachment", {"type": "bool"}),  # Attachment filtering
            (
                "parent_document_id",
                {"type": "keyword"},
            ),  # Hierarchical relationships
            ("original_file_type", {"type": "keyword"}),  # File type filtering
            ("is_converted", {"type": "bool"}),  # Conversion status filtering
        ]

        created_indexes = []
        failed_indexes = []

        for field_name, field_schema in indexes_to_create:
            try:
                client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_schema,  # type: ignore
                )
                created_indexes.append(field_name)
                self.logger.debug(f"Created payload index for field: {field_name}")
            except Exception as e:
                failed_indexes.append((field_name, str(e)))
                self.logger.warning(
                    f"Failed to create index for {field_name}", error=str(e)
                )

        # Log index creation summary
        self.logger.info(
            f"Collection {self.collection_name} created with indexes",
            created_indexes=created_indexes,
            failed_indexes=(
                [name for name, _ in failed_indexes] if failed_indexes else None
            ),
            total_indexes_created=len(created_indexes),
        )

        if failed_indexes:
            self.logger.warning(
                "Some indexes failed to create but collection is functional",
                failed_details=failed_indexes,
            )

    def create_collection(self) -> None:
        """Create the collection and its payload indexes if they don't exist."""
        try:
            client = self._ensure_client_connected()
            # No-op when the collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            vector_size = self._resolve_vector_size()

            # Create collection with basic configuration
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

            self._create_payload_indexes(client)
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        The blocking client call is offloaded to a thread so the event loop
        is not stalled.

        Args:
            points: List of points to upsert
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection.

        Args:
            query_vector: Query vector for similarity search
            limit: Maximum number of results to return
        """
        try:
            client = self._ensure_client_connected()
            # Use query_points API (qdrant-client 1.10+)
            query_response = client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                limit=limit,
            )
            return query_response.points
        except Exception as e:
            self.logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            client = self._ensure_client_connected()

            # Restrict matches to points whose project_id is in the given list
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )

            # Use query_points API (qdrant-client 1.10+)
            query_response = client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                query_filter=project_filter,
                limit=limit,
            )
            return query_response.points
        except Exception as e:
            self.logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        NOTE(review): the scroll is capped at 10000 points, so projects that
        only appear beyond that window would be missed — confirm acceptable.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()

            # Scroll through points to collect unique project-collection pairs
            scroll_result = client.scroll(
                collection_name=self.collection_name,
                limit=10000,  # Large limit to get all unique projects
                with_payload=True,
                with_vectors=False,
            )

            project_collections = {}
            for point in scroll_result[0]:
                if point.payload:
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name

            return project_collections
        except Exception as e:
            self.logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection."""
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            self.logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            self.logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        The blocking client call is offloaded to a thread so the event loop
        is not stalled.

        Args:
            document_ids: List of document IDs to delete
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise