Coverage for src / qdrant_loader / core / qdrant_manager.py: 71%

221 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:38 +0000

1import asyncio 

2from typing import Any, cast 

3from urllib.parse import urlparse 

4 

5from qdrant_client import QdrantClient 

6from qdrant_client.http import models 

7from qdrant_client.http.models import ( 

8 Distance, 

9 VectorParams, 

10) 

11from qdrant_loader_core.config import ( 

12 CollectionVectorCapabilities, 

13 SparseRuntimeConfig, 

14 parse_collection_capabilities, 

15) 

16from qdrant_loader_core.sparse import get_sparse_encoder 

17 

18from ..config import Settings, get_global_config, get_settings 

19from ..utils.logging import LoggingConfig 

20 

# Module-level logger. It is created with the same name (__name__) that
# QdrantManager.__init__ passes for self.logger.
logger = LoggingConfig.get_logger(__name__)

22 

23 

24class QdrantConnectionError(Exception): 

25 """Custom exception for Qdrant connection errors.""" 

26 

27 def __init__( 

28 self, message: str, original_error: str | None = None, url: str | None = None 

29 ): 

30 self.message = message 

31 self.original_error = original_error 

32 self.url = url 

33 super().__init__(self.message) 

34 

35 

class QdrantManager:
    """Manage one Qdrant collection: connection, schema, upserts, search, deletes.

    Point vectors are shaped to match the live collection schema, which may use
    an unnamed dense vector, a named dense vector, or named dense + sparse
    vectors (hybrid).
    """

    # Hosts treated as local: an API key may be sent to them over plain HTTP.
    _LOCAL_HOSTNAMES = frozenset({"localhost", "127.0.0.1"})
    # Request timeout passed to QdrantClient, in seconds.
    _CLIENT_TIMEOUT_SECONDS = 60
    # Points fetched per scroll() page in get_project_collections().
    _SCROLL_PAGE_SIZE = 10000
    # (field_name, field_schema) payload indexes created with a new collection.
    _PAYLOAD_INDEXES: tuple[tuple[str, dict[str, str]], ...] = (
        # Essential performance indexes
        ("document_id", {"type": "keyword"}),  # kept for backward compatibility
        ("project_id", {"type": "keyword"}),  # critical for multi-tenant filtering
        ("source_type", {"type": "keyword"}),  # document type filtering
        ("source", {"type": "keyword"}),  # source path filtering
        ("title", {"type": "keyword"}),  # title-based search and filtering
        ("created_at", {"type": "keyword"}),  # temporal filtering
        ("updated_at", {"type": "keyword"}),  # temporal filtering
        # Secondary performance indexes
        ("is_attachment", {"type": "bool"}),  # attachment filtering
        ("parent_document_id", {"type": "keyword"}),  # hierarchical relationships
        ("original_file_type", {"type": "keyword"}),  # file type filtering
        ("is_converted", {"type": "bool"}),  # conversion status filtering
    )

    def __init__(self, settings: Settings | None = None):
        """Initialize the qDrant manager and connect immediately.

        Args:
            settings: The application settings; ``get_settings()`` is used
                when omitted.

        Raises:
            QdrantConnectionError: If the client cannot be created.
        """
        self.settings = settings or get_settings()
        self.client: QdrantClient | None = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.sparse_runtime = self._resolve_sparse_runtime_config()
        # Lazily filled from the live collection schema; None means "not known yet".
        self._collection_vector_capabilities: CollectionVectorCapabilities | None = None
        # Ensures the dense-only fallback warning is logged at most once.
        self._sparse_fallback_warning_emitted = False
        self.connect()

    def _is_api_key_present(self) -> bool:
        """Return True if ``settings.qdrant_api_key`` holds a usable value.

        Empty or whitespace-only strings and the placeholders ``"none"`` /
        ``"null"`` (any case, surrounding whitespace ignored) count as absent.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # None, empty string, etc.
            return False
        normalized = api_key.strip().lower()
        return bool(normalized) and normalized not in ("none", "null")

    def _resolve_sparse_runtime_config(self) -> SparseRuntimeConfig:
        """Build the sparse-vector runtime config from the global ``llm`` section.

        Falls back to defaults (with a warning) when the global config cannot
        be read; a non-dict ``llm`` section is treated as empty.
        """
        try:
            llm = getattr(get_global_config(), "llm", None) or {}
        except Exception as e:
            self.logger.warning(
                "Failed to read global LLM config for sparse runtime; using defaults",
                error=str(e),
                exc_info=True,
            )
            llm = {}
        global_config = {"llm": llm} if isinstance(llm, dict) else {}
        return SparseRuntimeConfig.from_global_config(global_config)

    def _get_collection_vector_capabilities(self) -> CollectionVectorCapabilities:
        """Return (and cache) which vector kinds the live collection supports.

        Raises:
            QdrantConnectionError: If no client is connected.
        """
        if self._collection_vector_capabilities is not None:
            return self._collection_vector_capabilities

        client = self._ensure_client_connected()
        try:
            info = client.get_collection(collection_name=self.collection_name)
        except Exception as e:
            # Don't cache: a transient outage would otherwise pin every
            # subsequent upsert to dense-only payload shape even after Qdrant
            # becomes reachable again, mismatching a hybrid collection schema.
            self.logger.warning(
                "Failed to inspect Qdrant collection schema; assuming dense-only",
                collection=self.collection_name,
                error=str(e),
            )
            return CollectionVectorCapabilities()

        self._collection_vector_capabilities = parse_collection_capabilities(
            info, self.sparse_runtime
        )
        return self._collection_vector_capabilities

    def _dense_query_using(self) -> str | None:
        """Return the dense vector name to query, or None for an unnamed vector."""
        caps = self._get_collection_vector_capabilities()
        if caps.has_named_dense:
            return self.sparse_runtime.dense_vector_name
        return None

    def _sparse_upsert_enabled(self) -> bool:
        """Return True when sparse vectors are both requested and supported.

        Logs a single warning the first time sparse vectors are requested but
        the collection schema lacks a named dense + sparse layout.
        """
        if not self.sparse_runtime.enabled:
            return False

        caps = self._get_collection_vector_capabilities()
        if caps.has_named_dense and caps.has_sparse:
            return True

        if not self._sparse_fallback_warning_emitted:
            self.logger.warning(
                "Sparse vectors requested but collection schema does not support them; falling back to dense-only upserts",
                collection=self.collection_name,
                dense_vector_name=self.sparse_runtime.dense_vector_name,
                sparse_vector_name=self.sparse_runtime.sparse_vector_name,
            )
            self._sparse_fallback_warning_emitted = True
        return False

    def build_point_vector(self, dense_embedding: list[float], text: str) -> object:
        """Build the point vector payload for upsert.

        Three shapes are possible depending on the live collection schema:
        - dense+sparse named dict (hybrid-ready collection),
        - dense-only named dict (legacy named-vector collection),
        - raw dense list (legacy unnamed collection).
        """
        if self._sparse_upsert_enabled():
            return self._build_hybrid_payload(dense_embedding, text)
        return self._build_dense_payload(dense_embedding)

    def _build_dense_payload(self, dense_embedding: list[float]) -> object:
        """Return dense-only payload using the named-vector shape if the collection requires it."""
        if self._dense_query_using() is not None:
            return {self.sparse_runtime.dense_vector_name: dense_embedding}
        return dense_embedding

    def _build_hybrid_payload(self, dense_embedding: list[float], text: str) -> object:
        """Return dense+sparse payload, with a dense-only fallback on encode failure."""
        try:
            sparse = get_sparse_encoder(self.sparse_runtime.model).encode_document(text)
        except Exception as e:
            self.logger.warning(
                "Failed to generate sparse vectors; falling back to dense-only upsert",
                error=str(e),
            )
            return self._build_dense_payload(dense_embedding)

        if sparse.is_empty():
            return {self.sparse_runtime.dense_vector_name: dense_embedding}
        return {
            self.sparse_runtime.dense_vector_name: dense_embedding,
            self.sparse_runtime.sparse_vector_name: models.SparseVector(
                indices=sparse.indices, values=sparse.values
            ),
        }

    def _enforce_https_for_remote(self, url: str) -> str:
        """Return *url* with ``http`` upgraded to ``https`` unless the host is local.

        Only the ``http`` scheme is rewritten; URLs with any other scheme, or
        whose exact hostname is in ``_LOCAL_HOSTNAMES``, are returned unchanged.
        """
        parsed = urlparse(url)
        # Compare the parsed hostname exactly: a substring test on netloc would
        # treat e.g. "mylocalhost.example.com" as local and leak the key.
        if parsed.scheme != "http" or parsed.hostname in self._LOCAL_HOSTNAMES:
            return url
        self.logger.warning("Forcing HTTPS connection due to API key usage")
        return parsed._replace(scheme="https").geturl()

    def connect(self) -> None:
        """Establish connection to qDrant server.

        When an API key is configured and the URL is plain ``http`` to a
        non-local host, the scheme is upgraded to ``https`` first.

        Raises:
            QdrantConnectionError: "Connection error" if the client constructor
                fails, "Unexpected error" if reading settings or preparing the
                URL fails. The original exception is chained as ``__cause__``.
        """
        # Bound before the try so both error handlers can always reference it.
        url: str | None = None
        try:
            url = self.settings.qdrant_url
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )
            if api_key:
                url = self._enforce_https_for_remote(url)
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e

        # Kept separate so a client failure is not re-wrapped as "Unexpected error".
        try:
            self.client = QdrantClient(
                url=url,
                api_key=api_key,
                timeout=self._CLIENT_TIMEOUT_SECONDS,
            )
        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Connection error",
                original_error=str(e),
                url=url,
            ) from e

        # A new client may point at a different server; re-inspect the schema.
        self._collection_vector_capabilities = None
        self.logger.debug("Successfully connected to qDrant")

    def _ensure_client_connected(self) -> QdrantClient:
        """Return the connected client.

        Raises:
            QdrantConnectionError: If ``connect()`` has not produced a client.
        """
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    @staticmethod
    def _config_value(section: object, key: str) -> object:
        """Read *key* from a dict or attribute-style config section (None-safe)."""
        if section is None:
            return None
        if isinstance(section, dict):
            return section.get(key)
        return getattr(section, key, None)

    @staticmethod
    def _as_vector_size(value: object) -> int | None:
        """Return *value* as an int if it is an int (bools excluded), else None."""
        if isinstance(value, int) and not isinstance(value, bool):
            return int(value)
        return None

    def _resolve_vector_size(self) -> int:
        """Return the dense vector size to use when creating the collection.

        Lookup order: ``global.llm.embeddings.vector_size``, where the ``llm``
        section may be a dict (as in ``_resolve_sparse_runtime_config``) or an
        attribute-style object; then the legacy ``global.embedding.vector_size``;
        finally 1024 with a deprecation warning.
        """
        vector_size: int | None = None
        try:
            llm_settings = getattr(get_global_config(), "llm", None)
            embeddings_cfg = self._config_value(llm_settings, "embeddings")
            vector_size = self._as_vector_size(
                self._config_value(embeddings_cfg, "vector_size")
            )
        except Exception as e:
            self.logger.debug(
                "Could not read global.llm.embeddings.vector_size", error=str(e)
            )

        if vector_size is None:
            try:
                vector_size = self._as_vector_size(
                    get_global_config().embedding.vector_size
                )
            except Exception as e:
                self.logger.debug(
                    "Could not read legacy global.embedding.vector_size",
                    error=str(e),
                )

        if vector_size is None:
            self.logger.warning(
                "No vector_size specified in config; falling back to 1024 (deprecated default). Set global.llm.embeddings.vector_size."
            )
            vector_size = 1024
        return vector_size

    def _create_payload_indexes(self, client: QdrantClient) -> None:
        """Create ``_PAYLOAD_INDEXES`` on the collection; failures are logged, not raised."""
        created_indexes: list[str] = []
        failed_indexes: list[tuple[str, str]] = []

        for field_name, field_schema in self._PAYLOAD_INDEXES:
            try:
                client.create_payload_index(
                    collection_name=self.collection_name,
                    field_name=field_name,
                    field_schema=field_schema,  # type: ignore
                )
            except Exception as e:
                failed_indexes.append((field_name, str(e)))
                self.logger.warning(
                    f"Failed to create index for {field_name}", error=str(e)
                )
            else:
                created_indexes.append(field_name)
                self.logger.debug(f"Created payload index for field: {field_name}")

        self.logger.info(
            f"Collection {self.collection_name} created with indexes",
            created_indexes=created_indexes,
            failed_indexes=(
                [name for name, _ in failed_indexes] if failed_indexes else None
            ),
            total_indexes_created=len(created_indexes),
        )

        if failed_indexes:
            self.logger.warning(
                "Some indexes failed to create but collection is functional",
                failed_details=failed_indexes,
            )

    def create_collection(self) -> None:
        """Create the configured collection and its payload indexes if missing.

        With ``sparse_runtime.enabled`` the collection gets a named dense vector
        plus a named sparse vector; otherwise a single unnamed dense vector.
        Payload-index failures are logged but do not fail the call.

        Raises:
            Exception: Errors from listing or creating the collection are logged
                and re-raised.
        """
        try:
            client = self._ensure_client_connected()
            existing = client.get_collections().collections
            if any(c.name == self.collection_name for c in existing):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            vector_size = self._resolve_vector_size()

            # sparse.enabled is a strict declaration. If True, the collection
            # is created with a sparse vector; failures propagate. If False,
            # dense-only. Operators on Qdrant servers that don't support sparse
            # vectors must set sparse.enabled=false explicitly.
            dense_params = VectorParams(size=vector_size, distance=Distance.COSINE)
            if self.sparse_runtime.enabled:
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config={
                        self.sparse_runtime.dense_vector_name: dense_params
                    },
                    sparse_vectors_config={
                        self.sparse_runtime.sparse_vector_name: models.SparseVectorParams()
                    },
                )
                self.logger.info(
                    "Created Qdrant collection with dense+sparse vectors",
                    collection=self.collection_name,
                    dense_vector_name=self.sparse_runtime.dense_vector_name,
                    sparse_vector_name=self.sparse_runtime.sparse_vector_name,
                    sparse_model=self.sparse_runtime.model,
                )
            else:
                client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=dense_params,
                )

            # We just chose the schema, so no need to re-inspect it.
            self._collection_vector_capabilities = CollectionVectorCapabilities(
                has_named_dense=self.sparse_runtime.enabled,
                has_sparse=self.sparse_runtime.enabled,
            )

            self._create_payload_indexes(client)
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        The blocking client call runs in a worker thread via ``asyncio.to_thread``.
        An empty list is a no-op.

        Args:
            points: List of points to upsert

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        if not points:
            self.logger.debug("No points to upsert", collection=self.collection_name)
            return

        self.logger.debug(
            "Upserting points",
            point_count=len(points),
            collection=self.collection_name,
        )
        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                error=str(e),
                point_count=len(points),
                collection=self.collection_name,
            )
            raise
        self.logger.debug(
            "Successfully upserted points",
            point_count=len(points),
            collection=self.collection_name,
        )

    def _query_points(
        self,
        query_vector: list[float],
        limit: int,
        query_filter: models.Filter | None = None,
    ) -> list[models.ScoredPoint]:
        """Run ``query_points`` against the dense vector the collection uses."""
        client = self._ensure_client_connected()
        query_kwargs: dict[str, Any] = {
            "collection_name": self.collection_name,
            "query": query_vector,
            "limit": limit,
        }
        if query_filter is not None:
            query_kwargs["query_filter"] = query_filter
        using = self._dense_query_using()
        if using:
            query_kwargs["using"] = using
        # Use query_points API (qdrant-client 1.10+)
        return client.query_points(**query_kwargs).points

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection.

        Args:
            query_vector: Query vector for similarity search
            limit: Maximum number of results to return

        Returns:
            Scored points returned by Qdrant.
        """
        try:
            return self._query_points(query_vector, limit)
        except Exception as e:
            self.logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
            return self._query_points(query_vector, limit, project_filter)
        except Exception as e:
            self.logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Scrolls through every page of the collection. Points whose payload
        lacks either ``project_id`` or ``collection_name`` are skipped. When a
        project appears with several collection names, the last one seen wins.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()
            project_collections: dict[str, str] = {}
            offset: Any = None
            while True:
                scroll_kwargs: dict[str, Any] = {
                    "collection_name": self.collection_name,
                    "limit": self._SCROLL_PAGE_SIZE,
                    "with_payload": True,
                    "with_vectors": False,
                }
                if offset is not None:
                    scroll_kwargs["offset"] = offset
                points, offset = client.scroll(**scroll_kwargs)
                for point in points:
                    if not point.payload:
                        continue
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name
                # scroll() returns None as the next offset after the last page.
                if offset is None:
                    break
            return project_collections
        except Exception as e:
            self.logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection and forget its cached vector schema."""
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
        except Exception as e:
            self.logger.error("Failed to delete collection", error=str(e))
            raise
        # The schema no longer exists; re-inspect (and re-warn) on next use.
        self._collection_vector_capabilities = None
        self._sparse_fallback_warning_emitted = False
        self.logger.debug("Collection deleted", collection=self.collection_name)

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        The blocking client call runs in a worker thread via ``asyncio.to_thread``.
        An empty list is a no-op.

        Args:
            document_ids: List of document IDs to delete

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        if not document_ids:
            self.logger.debug(
                "No document IDs to delete", collection=self.collection_name
            )
            return

        self.logger.debug(
            "Deleting points by document ID",
            document_count=len(document_ids),
            collection=self.collection_name,
        )
        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                error=str(e),
                document_count=len(document_ids),
                collection=self.collection_name,
            )
            raise
        self.logger.debug(
            "Successfully deleted points",
            document_count=len(document_ids),
            collection=self.collection_name,
        )