Coverage for src/qdrant_loader/core/qdrant_manager.py: 82%
123 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-04 05:50 +0000
1import asyncio
2from typing import cast
3from urllib.parse import urlparse
5from qdrant_client import QdrantClient
6from qdrant_client.http import models
7from qdrant_client.http.models import (
8 Distance,
9 VectorParams,
10)
12from ..config import Settings, get_global_config, get_settings
13from ..utils.logging import LoggingConfig
# Module-level logger for this module (configured via LoggingConfig).
logger = LoggingConfig.get_logger(__name__)
18class QdrantConnectionError(Exception):
19 """Custom exception for Qdrant connection errors."""
21 def __init__(
22 self, message: str, original_error: str | None = None, url: str | None = None
23 ):
24 self.message = message
25 self.original_error = original_error
26 self.url = url
27 super().__init__(self.message)
class QdrantManager:
    """Manage the configured qDrant collection.

    Thin wrapper around :class:`QdrantClient` that handles connecting,
    creating/deleting the collection, and upserting/searching/deleting
    points. The connection is established eagerly on construction.
    """

    def __init__(self, settings: Settings | None = None):
        """Initialize the qDrant manager.

        Args:
            settings: The application settings. Falls back to the global
                settings when not provided.

        Raises:
            QdrantConnectionError: If the initial connection attempt fails.
        """
        self.settings = settings or get_settings()
        self.client: QdrantClient | None = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        # Connect eagerly so construction fails fast on bad configuration.
        self.connect()

    def _is_api_key_present(self) -> bool:
        """
        Check if a valid API key is present.
        Returns True if the API key is a non-empty string that is not 'None' or 'null'.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        # Guard against the literal strings "None"/"null" leaking in from
        # environment variables or config files (case-insensitive).
        return api_key.lower() not in ["none", "null"]

    def connect(self) -> None:
        """Establish connection to qDrant server.

        Forces HTTPS for non-local URLs when an API key is configured, so
        the key is never sent over a plain-text channel.

        Raises:
            QdrantConnectionError: If the client cannot be created.
        """
        # Bind url before the try block so the error path can always
        # report which endpoint was being contacted.
        url = self.settings.qdrant_url
        try:
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Only force HTTPS for non-local URLs
                if parsed_url.scheme != "https" and not any(
                    host in parsed_url.netloc for host in ["localhost", "127.0.0.1"]
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning("Forcing HTTPS connection due to API key usage")

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to qDrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to qDrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e
        except QdrantConnectionError:
            # BUGFIX: let the specific connection error above propagate
            # unchanged instead of re-wrapping it as "Unexpected error".
            raise
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e

    def _ensure_client_connected(self) -> QdrantClient:
        """Ensure the client is connected before performing operations."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    def create_collection(self) -> None:
        """Create a new collection if it doesn't exist.

        Uses cosine distance and the configured embedding vector size
        (defaulting to 1536 when unset), and indexes the ``document_id``
        payload field for fast deletes/filters.

        Raises:
            QdrantConnectionError: If the client is not connected.
            Exception: Any error from the qDrant API, after logging.
        """
        try:
            client = self._ensure_client_connected()
            # Check if collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            # Get vector size from configuration
            vector_size = get_global_config().embedding.vector_size
            if not vector_size:
                self.logger.warning(
                    "No vector_size specified in config, defaulting to 1536"
                )
                vector_size = 1536

            # Create collection with basic configuration
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

            # Create index for document_id field
            client.create_payload_index(
                collection_name=self.collection_name,
                field_name="document_id",
                field_schema={"type": "keyword"},  # type: ignore
            )

            self.logger.debug(f"Collection {self.collection_name} created successfully")
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        The blocking client call runs in a worker thread so the event loop
        is not blocked.

        Args:
            points: List of points to upsert
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection.

        Args:
            query_vector: Query vector for similarity search
            limit: Maximum number of results to return

        Returns:
            List of scored points ordered by similarity.
        """
        try:
            client = self._ensure_client_connected()
            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            # Consistency fix: use the instance logger like the other methods.
            self.logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            client = self._ensure_client_connected()

            # Restrict matches to points whose project_id payload field is
            # one of the requested projects.
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )

            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=project_filter,
                limit=limit,
            )
            return search_result
        except Exception as e:
            self.logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Scans point payloads for ``project_id``/``collection_name`` pairs.
        NOTE(review): the scroll is capped at 10000 points, so projects
        appearing only beyond that window would be missed — confirm the
        collection size stays within this bound.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()

            # Scroll through all points to get unique project-collection mappings
            scroll_result = client.scroll(
                collection_name=self.collection_name,
                limit=10000,  # Large limit to get all unique projects
                with_payload=True,
                with_vectors=False,
            )

            project_collections = {}
            for point in scroll_result[0]:
                if point.payload:
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name

            return project_collections
        except Exception as e:
            self.logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection.

        Raises:
            QdrantConnectionError: If the client is not connected.
            Exception: Any error from the qDrant API, after logging.
        """
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            self.logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            self.logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        The blocking client call runs in a worker thread so the event loop
        is not blocked. Relies on the ``document_id`` payload index created
        by :meth:`create_collection`.

        Args:
            document_ids: List of document IDs to delete
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise