Coverage for src/qdrant_loader/core/qdrant_manager.py: 80%

152 statements  

coverage.py v7.10.6, created at 2025-09-08 06:05 +0000

import asyncio
from typing import cast
from urllib.parse import urlparse

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import (
    Distance,
    VectorParams,
)

from ..config import Settings, get_global_config, get_settings
from ..utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class QdrantConnectionError(Exception):
    """Custom exception for Qdrant connection errors."""

    def __init__(
        self, message: str, original_error: str | None = None, url: str | None = None
    ):
        self.message = message
        self.original_error = original_error
        self.url = url
        super().__init__(self.message)


class QdrantManager:
    def __init__(self, settings: Settings | None = None):
        """Initialize the Qdrant manager.

        Args:
            settings: The application settings
        """
        self.settings = settings or get_settings()
        self.client = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.connect()

    def _is_api_key_present(self) -> bool:
        """
        Check if a valid API key is present.
        Returns True if the API key is a non-empty string that is not 'None' or 'null'.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        return api_key.lower() not in ["none", "null"]

    def connect(self) -> None:
        """Establish connection to the Qdrant server."""
        try:
            # Ensure HTTPS is used when an API key is present, but only for non-local URLs
            url = self.settings.qdrant_url
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Only force HTTPS for non-local URLs
                if parsed_url.scheme != "https" and not any(
                    host in parsed_url.netloc for host in ["localhost", "127.0.0.1"]
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning("Forcing HTTPS connection due to API key usage")

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to Qdrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to Qdrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e

        except QdrantConnectionError:
            # Propagate connection errors as-is instead of re-wrapping them below.
            raise
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to Qdrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e
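    # Example of the HTTPS upgrade above, with a hypothetical host: given an API
    # key, "http://qdrant.example.com:6333" becomes "https://qdrant.example.com:6333",
    # while local URLs such as "http://localhost:6333" are left on HTTP.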

    def _ensure_client_connected(self) -> QdrantClient:
        """Ensure the client is connected before performing operations."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    def create_collection(self) -> None:
        """Create a new collection if it doesn't exist."""
        try:
            client = self._ensure_client_connected()
            # Check if collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            # Get vector size from unified LLM settings first, then legacy embedding
            vector_size: int | None = None
            try:
                global_cfg = get_global_config()
                llm_settings = getattr(global_cfg, "llm", None)
                if llm_settings is not None:
                    embeddings_cfg = getattr(llm_settings, "embeddings", None)
                    vs = (
                        getattr(embeddings_cfg, "vector_size", None)
                        if embeddings_cfg is not None
                        else None
                    )
                    if isinstance(vs, int):
                        vector_size = int(vs)
            except Exception:
                vector_size = None

            if vector_size is None:
                try:
                    legacy_vs = get_global_config().embedding.vector_size
                    if isinstance(legacy_vs, int):
                        vector_size = int(legacy_vs)
                except Exception:
                    vector_size = None

            if vector_size is None:
                self.logger.warning(
                    "No vector_size specified in config; falling back to 1536 (deprecated default). Set global.llm.embeddings.vector_size."
                )
                vector_size = 1536

            # Create collection with basic configuration
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

            # Create payload indexes for optimal search performance
            indexes_to_create = [
                # Essential performance indexes
                ("document_id", {"type": "keyword"}),  # Existing index, kept for backward compatibility
                ("project_id", {"type": "keyword"}),  # Critical for multi-tenant filtering
                ("source_type", {"type": "keyword"}),  # Document type filtering
                ("source", {"type": "keyword"}),  # Source path filtering
                ("title", {"type": "keyword"}),  # Title-based search and filtering
                ("created_at", {"type": "keyword"}),  # Temporal filtering
                ("updated_at", {"type": "keyword"}),  # Temporal filtering
                # Secondary performance indexes
                ("is_attachment", {"type": "bool"}),  # Attachment filtering
                ("parent_document_id", {"type": "keyword"}),  # Hierarchical relationships
                ("original_file_type", {"type": "keyword"}),  # File type filtering
                ("is_converted", {"type": "bool"}),  # Conversion status filtering
            ]

            # Create indexes with proper error handling
            created_indexes = []
            failed_indexes = []

            for field_name, field_schema in indexes_to_create:
                try:
                    client.create_payload_index(
                        collection_name=self.collection_name,
                        field_name=field_name,
                        field_schema=field_schema,  # type: ignore
                    )
                    created_indexes.append(field_name)
                    self.logger.debug(f"Created payload index for field: {field_name}")
                except Exception as e:
                    failed_indexes.append((field_name, str(e)))
                    self.logger.warning(
                        f"Failed to create index for {field_name}", error=str(e)
                    )

            # Log index creation summary
            self.logger.info(
                f"Collection {self.collection_name} created with indexes",
                created_indexes=created_indexes,
                failed_indexes=(
                    [name for name, _ in failed_indexes] if failed_indexes else None
                ),
                total_indexes_created=len(created_indexes),
            )

            if failed_indexes:
                self.logger.warning(
                    "Some indexes failed to create but collection is functional",
                    failed_details=failed_indexes,
                )
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise
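    # For reference, the vector-size lookup in create_collection() resolves, in
    # order: global.llm.embeddings.vector_size, then the legacy
    # embedding.vector_size, then the deprecated 1536 default (with a warning).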

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        Args:
            points: List of points to upsert
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            # Run the blocking client call in a worker thread so the event loop stays free.
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection."""
        try:
            client = self._ensure_client_connected()
            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            client = self._ensure_client_connected()

            # Build project filter
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )

            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=project_filter,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()

            # Scroll through points to collect unique project-collection mappings.
            # Note: only the first scroll page is read, so projects appearing only
            # beyond the first 10,000 points would be missed.
            scroll_result = client.scroll(
                collection_name=self.collection_name,
                limit=10000,  # Large limit to get all unique projects
                with_payload=True,
                with_vectors=False,
            )

            project_collections = {}
            for point in scroll_result[0]:
                if point.payload:
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name

            return project_collections
        except Exception as e:
            logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection."""
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        Args:
            document_ids: List of document IDs to delete
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise
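
A minimal usage sketch of this manager (illustrative only; it assumes get_settings() and get_global_config() are already loaded, and the IDs and zero vectors are placeholders — real vectors must match the collection's vector_size):

import asyncio

from qdrant_client.http import models

from qdrant_loader.core.qdrant_manager import QdrantManager


async def main() -> None:
    manager = QdrantManager()  # connects during __init__ via connect()
    manager.create_collection()  # returns early if the collection already exists

    # Upsert one point; payload fields match the indexes created above.
    point = models.PointStruct(
        id=1,
        vector=[0.0] * 1536,
        payload={"document_id": "doc-1", "project_id": "proj-1"},
    )
    await manager.upsert_points([point])

    # Similarity search scoped to specific projects.
    hits = manager.search_with_project_filter(
        query_vector=[0.0] * 1536, project_ids=["proj-1"], limit=5
    )
    for hit in hits:
        print(hit.id, hit.score)


asyncio.run(main())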