Coverage for src/qdrant_loader/core/qdrant_manager.py: 80% (152 statements)
coverage.py v7.10.6, created at 2025-09-08 06:05 +0000
import asyncio
from typing import cast
from urllib.parse import urlparse

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import (
    Distance,
    VectorParams,
)

from ..config import Settings, get_global_config, get_settings
from ..utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class QdrantConnectionError(Exception):
    """Custom exception for Qdrant connection errors."""

    def __init__(
        self, message: str, original_error: str | None = None, url: str | None = None
    ):
        self.message = message
        self.original_error = original_error
        self.url = url
        super().__init__(self.message)


class QdrantManager:
    def __init__(self, settings: Settings | None = None):
        """Initialize the Qdrant manager.

        Args:
            settings: The application settings
        """
        self.settings = settings or get_settings()
        self.client = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.connect()

    def _is_api_key_present(self) -> bool:
        """
        Check if a valid API key is present.
        Returns True if the API key is a non-empty string that is not 'None' or 'null'.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        return api_key.lower() not in ["none", "null"]
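
    # Illustration (derived from the check above): None, "", "None", and "null"
    # (any casing) are treated as "no API key"; any other non-empty string
    # counts as present.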

    def connect(self) -> None:
        """Establish connection to the Qdrant server."""
        try:
            # Ensure HTTPS is used when an API key is present, but only for non-local URLs
            url = self.settings.qdrant_url
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Only force HTTPS for non-local URLs
                if parsed_url.scheme != "https" and not any(
                    host in parsed_url.netloc for host in ["localhost", "127.0.0.1"]
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning("Forcing HTTPS connection due to API key usage")

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to Qdrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to Qdrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e

        except QdrantConnectionError:
            # Re-raise the specific connection error instead of letting the
            # generic handler below re-wrap it as an "Unexpected error"
            raise
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to Qdrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e
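
    # Hedged sketch of the HTTPS upgrade (hypothetical values): with
    # qdrant_url="http://qdrant.example.com:6333" and an API key set,
    # connect() rewrites the URL to "https://qdrant.example.com:6333" and logs
    # a warning; "http://localhost:6333" is left untouched.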

    def _ensure_client_connected(self) -> QdrantClient:
        """Ensure the client is connected before performing operations."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    def create_collection(self) -> None:
        """Create a new collection if it doesn't exist."""
        try:
            client = self._ensure_client_connected()
            # Check if the collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            # Get the vector size from unified LLM settings first, then legacy embedding
            vector_size: int | None = None
            try:
                global_cfg = get_global_config()
                llm_settings = getattr(global_cfg, "llm", None)
                if llm_settings is not None:
                    embeddings_cfg = getattr(llm_settings, "embeddings", None)
                    vs = (
                        getattr(embeddings_cfg, "vector_size", None)
                        if embeddings_cfg is not None
                        else None
                    )
                    if isinstance(vs, int):
                        vector_size = int(vs)
            except Exception:
                vector_size = None

            if vector_size is None:
                try:
                    legacy_vs = get_global_config().embedding.vector_size
                    if isinstance(legacy_vs, int):
                        vector_size = int(legacy_vs)
                except Exception:
                    vector_size = None

            if vector_size is None:
                self.logger.warning(
                    "No vector_size specified in config; falling back to 1536 "
                    "(deprecated default). Set global.llm.embeddings.vector_size."
                )
                vector_size = 1536
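
            # Resolution order illustrated with hypothetical config values:
            #   global.llm.embeddings.vector_size = 3072  -> vector_size = 3072
            #   only global.embedding.vector_size = 768   -> vector_size = 768
            #   neither set                               -> vector_size = 1536 (warning)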

            # Create the collection with a basic configuration
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

            # Create payload indexes for optimal search performance
            indexes_to_create = [
                # Essential performance indexes
                (
                    "document_id",
                    {"type": "keyword"},
                ),  # Existing index, kept for backward compatibility
                (
                    "project_id",
                    {"type": "keyword"},
                ),  # Critical for multi-tenant filtering
                ("source_type", {"type": "keyword"}),  # Document type filtering
                ("source", {"type": "keyword"}),  # Source path filtering
                ("title", {"type": "keyword"}),  # Title-based search and filtering
                ("created_at", {"type": "keyword"}),  # Temporal filtering
                ("updated_at", {"type": "keyword"}),  # Temporal filtering
                # Secondary performance indexes
                ("is_attachment", {"type": "bool"}),  # Attachment filtering
                (
                    "parent_document_id",
                    {"type": "keyword"},
                ),  # Hierarchical relationships
                ("original_file_type", {"type": "keyword"}),  # File type filtering
                ("is_converted", {"type": "bool"}),  # Conversion status filtering
            ]

            # Create indexes with proper error handling
            created_indexes = []
            failed_indexes = []

            for field_name, field_schema in indexes_to_create:
                try:
                    client.create_payload_index(
                        collection_name=self.collection_name,
                        field_name=field_name,
                        field_schema=field_schema,  # type: ignore
                    )
                    created_indexes.append(field_name)
                    self.logger.debug(f"Created payload index for field: {field_name}")
                except Exception as e:
                    failed_indexes.append((field_name, str(e)))
                    self.logger.warning(
                        f"Failed to create index for {field_name}", error=str(e)
                    )

            # Log an index-creation summary
            self.logger.info(
                f"Collection {self.collection_name} created with indexes",
                created_indexes=created_indexes,
                failed_indexes=(
                    [name for name, _ in failed_indexes] if failed_indexes else None
                ),
                total_indexes_created=len(created_indexes),
            )

            if failed_indexes:
                self.logger.warning(
                    "Some indexes failed to create but collection is functional",
                    failed_details=failed_indexes,
                )
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        Args:
            points: List of points to upsert
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise
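
    # Minimal usage sketch (IDs and dimensions are hypothetical; Qdrant accepts
    # UUID strings or unsigned integers as point IDs):
    #
    #   from uuid import uuid4
    #   point = models.PointStruct(
    #       id=str(uuid4()),
    #       vector=[0.0] * 1536,               # must match the collection's vector size
    #       payload={"document_id": "doc-1"},  # enables delete_points_by_document_id
    #   )
    #   await manager.upsert_points([point])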

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection."""
        try:
            client = self._ensure_client_connected()
            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error("Failed to search collection", error=str(e))
            raise
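
    # Sketch: hits = manager.search(query_vector=embedding, limit=5) returns
    # models.ScoredPoint objects exposing .id, .score, and .payload.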

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            client = self._ensure_client_connected()

            # Build the project filter
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )

            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=project_filter,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise
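
    # Sketch: manager.search_with_project_filter(embedding, ["proj-a", "proj-b"])
    # returns only points whose payload "project_id" equals one of the given IDs
    # (MatchAny behaves like an OR across the list).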

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()

            # Fetch one large page of points (up to 10,000) to collect
            # unique project-to-collection mappings
            scroll_result = client.scroll(
                collection_name=self.collection_name,
                limit=10000,  # Large limit intended to cover all unique projects
                with_payload=True,
                with_vectors=False,
            )

            project_collections = {}
            for point in scroll_result[0]:
                if point.payload:
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name

            return project_collections
        except Exception as e:
            logger.error("Failed to get project collections", error=str(e))
            raise
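
    # Note (assumption about intended behavior): client.scroll returns a
    # (points, next_page_offset) tuple, so collections holding more than
    # 10,000 points would need a loop that feeds next_page_offset back via the
    # `offset` parameter to see every mapping.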

    def delete_collection(self) -> None:
        """Delete the collection."""
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        Args:
            document_ids: List of document IDs to delete
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise
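

# Hedged end-to-end sketch (assumes settings are already configured for
# get_settings(); `embedding` stands in for a hypothetical query vector):
#
#   manager = QdrantManager()        # connects on construction
#   manager.create_collection()      # no-op if the collection already exists
#   results = manager.search(query_vector=embedding, limit=5)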