Coverage for src / qdrant_loader_mcp_server / search / components / vector_search_service.py: 95%

183 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:38 +0000

1"""Vector search service for hybrid search.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import hashlib 

7import time 

8from asyncio import Lock 

9from contextvars import ContextVar 

10from dataclasses import dataclass 

11from typing import TYPE_CHECKING, Any 

12 

13if TYPE_CHECKING: 

14 from qdrant_client import AsyncQdrantClient 

15 from qdrant_client.http import models as qdrant_models 

16 

17from qdrant_client.http import models 

18from qdrant_loader_core.config import ( 

19 CollectionVectorCapabilities, 

20 parse_collection_capabilities, 

21) 

22from qdrant_loader_core.sparse import get_sparse_encoder 

23 

24from ...utils.logging import LoggingConfig 

25from ..sparse_config import load_sparse_runtime_config 

26from .field_query_parser import FieldQueryParser 

27 

28# Task-local flag set by vector_search when Qdrant fusion is used for the 

29# current query. ContextVar isolates concurrent searches that share the same 

30# VectorSearchService instance. 

31_used_qdrant_hybrid_ctx: ContextVar[bool] = ContextVar( 

32 "vector_search_used_qdrant_hybrid", default=False 

33) 

34 

35 

@dataclass
class FilterResult:
    """Minimal scored-point record for hits from the filter-only scroll path.

    It exposes the same ``score`` and ``payload`` attributes that hit
    extraction reads from Qdrant scored points. As a result, scroll matches
    can be projected by the same code as vector-search hits.
    """

    # Relevance score; the filter-only search path assigns a constant 1.0.
    score: float
    # Point payload as returned by the Qdrant scroll call.
    payload: dict[str, Any]

40 

41 

class VectorSearchService:
    """Handles vector search operations using Qdrant.

    Each query is routed to one of three retrieval paths:

    * a scroll-based exact-match path, when the query consists only of
      ``field:value`` filters;
    * a server-side dense+sparse RRF fusion query, when the collection schema
      and the sparse runtime config both allow it;
    * a plain dense vector query otherwise.

    Results are optionally memoised in an in-process TTL cache keyed by the
    search parameters.
    """

    # Retry policy for the embeddings provider used by ``get_embedding``.
    # These are class attributes so they can be tuned per instance (e.g. in
    # tests) without changing the constructor signature.
    _EMBED_MAX_ATTEMPTS: int = 3
    _EMBED_RETRY_DELAY_SECONDS: float = 0.5

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient).
            collection_name: Name of the Qdrant collection.
            min_score: Score threshold passed to dense queries; also part of
                the cache key.
            cache_enabled: Whether to enable search result caching.
            cache_ttl: Cache time-to-live in seconds.
            cache_max_size: Maximum number of cached results.
            hnsw_ef: HNSW ``ef`` value passed in Qdrant ``SearchParams``.
            use_exact_search: Request exact search via Qdrant ``SearchParams``.
            embeddings_provider: Preferred embeddings source. Either an object
                exposing ``.embeddings()`` or a client exposing ``async embed``.
            openai_client: Optional OpenAI-style client, used as a fallback
                when the provider is missing or keeps failing.
            embedding_model: Model name passed to the OpenAI fallback.
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        # Maps cache_key to {"results", "timestamp", "used_qdrant_hybrid"}.
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        self.sparse_runtime = load_sparse_runtime_config()
        # Probed lazily on first use and cached only after a successful probe.
        self._collection_capabilities: CollectionVectorCapabilities | None = None
        self._capabilities_lock: Lock = Lock()

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters (order-insensitive)

        Returns:
            SHA256 hex digest of the search parameters.
        """
        # Sort project IDs so that ["a", "b"] and ["b", "a"] share a key.
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Drop expired entries, then evict the oldest ones beyond ``cache_max_size``.

        Synchronous (no awaits). Callers in this class invoke it while holding
        ``_cache_lock``.
        """
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries first (simple FIFO eviction by insert time).
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def _get_collection_capabilities(self) -> CollectionVectorCapabilities:
        """Probe the collection for named-dense and sparse vector support.

        Transient ``get_collection`` failures are logged and returned as
        empty capabilities for the current call only. They are never cached,
        so the next call retries.
        """
        if self._collection_capabilities is not None:
            return self._collection_capabilities

        async with self._capabilities_lock:
            # Re-check: another task may have finished the probe while we waited.
            if self._collection_capabilities is not None:
                return self._collection_capabilities
            try:
                info = await self.qdrant_client.get_collection(
                    collection_name=self.collection_name
                )
            except Exception as e:
                # Don't cache: a transient outage would otherwise pin every
                # subsequent query to the dense fallback path for this
                # service's lifetime.
                self.logger.warning(
                    "Failed to inspect Qdrant collection schema; assuming dense-only",
                    collection=self.collection_name,
                    error=str(e),
                )
                return CollectionVectorCapabilities()
            self._collection_capabilities = parse_collection_capabilities(
                info, self.sparse_runtime
            )
            return self._collection_capabilities

    def _dense_using(self, caps: CollectionVectorCapabilities) -> str | None:
        """Return the named-dense vector key, or None for legacy unnamed collections."""
        return self.sparse_runtime.dense_vector_name if caps.has_named_dense else None

    def _hybrid_query_active(self, caps: CollectionVectorCapabilities) -> bool:
        """Combine collection schema with runtime config to decide on server-side fusion."""
        return (
            caps.hybrid_ready
            and self.sparse_runtime.enabled
            and self.sparse_runtime.use_qdrant_hybrid
        )

    async def supports_qdrant_hybrid(self) -> bool:
        """Return True if the collection is configured for Qdrant dense+sparse fusion.

        HybridPipeline uses this to decide whether to skip the separate
        keyword search. Calling it lets the pipeline preserve parallelism in
        the dense-only path.
        """
        caps = await self._get_collection_capabilities()
        return self._hybrid_query_active(caps)

    def _encode_sparse_query(self, text: str) -> models.SparseVector | None:
        """Encode ``text`` with the configured sparse model.

        Returns:
            A Qdrant ``SparseVector``, or None when the encoder produced an
            empty vector.
        """
        sparse = get_sparse_encoder(self.sparse_runtime.model).encode_query(text)
        if sparse.is_empty():
            return None
        return models.SparseVector(indices=sparse.indices, values=sparse.values)

    def used_qdrant_hybrid_last_query(self) -> bool:
        """Return whether the most recent vector_search in this task used Qdrant fusion.

        The flag is backed by a ContextVar, so concurrent searches across
        tasks (e.g. multiple MCP requests sharing this service) don't clobber
        each other. For cache hits, it reports how the cached results were
        originally produced.
        """
        return _used_qdrant_hybrid_ctx.get()

    async def get_embedding(self, text: str) -> list[float]:
        """Get an embedding for ``text``.

        Resolution order:

        1. If an embeddings provider is configured, try it up to
           ``_EMBED_MAX_ATTEMPTS`` times, waiting
           ``_EMBED_RETRY_DELAY_SECONDS`` between attempts.
        2. If the provider is missing or every attempt fails, fall back to
           the OpenAI client when one is configured.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            RuntimeError: If nothing is configured, or the provider failed on
                every attempt and there is no OpenAI fallback. In the second
                case the last provider error is chained as ``__cause__``.
            Exception: Any error raised by the OpenAI fallback is re-raised.
        """
        attempts = max(1, int(self._EMBED_MAX_ATTEMPTS))
        last_provider_error: Exception | None = None

        # Prefer provider when available
        if self.embeddings_provider is not None:
            # Accept either a provider (with .embeddings()) or a direct embeddings client
            client = (
                self.embeddings_provider.embeddings()
                if hasattr(self.embeddings_provider, "embeddings")
                else self.embeddings_provider
            )
            for attempt in range(1, attempts + 1):
                try:
                    vectors = await client.embed([text])
                    return vectors[0]
                except Exception as e:
                    last_provider_error = e
                    self.logger.warning(
                        "Provider embedding failed, retrying...", error=str(e)
                    )
                    # Only back off when another attempt will actually follow.
                    if attempt < attempts:
                        await asyncio.sleep(self._EMBED_RETRY_DELAY_SECONDS)

        # Fallback to OpenAI (to keep backward compatibility & pass tests)
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model=self.embedding_model,
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                self.logger.error("OpenAI fallback failed", error=str(e))
                raise

        if last_provider_error is not None:
            # A provider *was* configured but kept failing; say so and keep
            # the underlying cause instead of claiming nothing is configured.
            raise RuntimeError(
                f"Embeddings provider failed after {attempts} attempts "
                "and no OpenAI client is configured"
            ) from last_provider_error

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type.
            Afterwards, ``used_qdrant_hybrid_last_query()`` reflects how these
            results were produced, including on cache hits.
        """
        # Reset before the cache lookup so used_qdrant_hybrid_last_query()
        # reflects this call rather than the previous task-local value. On a
        # cache hit, _cache_get_if_valid restores the flag recorded with the
        # cached results.
        _used_qdrant_hybrid_ctx.set(False)
        cache_key = self._generate_cache_key(query, limit, project_ids)
        cached = await self._cache_get_if_valid(cache_key, query)
        if cached is not None:
            return cached

        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, "
            f"text: '{parsed_query.text_query}'"
        )

        if self.field_parser.should_use_filter_only(parsed_query):
            results = await self._run_filter_only_search(
                parsed_query, project_ids, limit
            )
        else:
            results = await self._run_vector_search(
                parsed_query, project_ids, query, limit
            )

        extracted = self._extract_hits(results)
        await self._cache_put(cache_key, extracted, query)
        return extracted

    async def _cache_get_if_valid(
        self, cache_key: str, query: str
    ) -> list[dict[str, Any]] | None:
        """Return cached results if present and not expired; otherwise count a miss.

        On a hit, the task-local hybrid flag is restored to the value recorded
        when the entry was stored. Callers can then still tell fused RRF
        results apart from dense results.
        """
        if self.cache_enabled:
            async with self._cache_lock:
                self._cleanup_expired_cache()
                cached_entry = self._search_cache.get(cache_key)
                if (
                    cached_entry is not None
                    and time.time() - cached_entry["timestamp"] <= self.cache_ttl
                ):
                    self._cache_hits += 1
                    # Entries written without the flag default to dense-only.
                    _used_qdrant_hybrid_ctx.set(
                        bool(cached_entry.get("used_qdrant_hybrid", False))
                    )
                    self.logger.debug(
                        "Search cache hit",
                        query=query[:50],
                        cache_hits=self._cache_hits,
                        cache_misses=self._cache_misses,
                        hit_rate=(
                            f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%"
                        ),
                    )
                    return cached_entry["results"]

        self._cache_misses += 1
        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )
        return None

    async def _cache_put(
        self, cache_key: str, results: list[dict[str, Any]], query: str
    ) -> None:
        """Store ``results`` under ``cache_key`` when caching is enabled.

        The current task-local hybrid flag is stored alongside the results so
        that a later cache hit can report it. The TTL and size bounds are
        re-applied right after insertion.
        """
        if not self.cache_enabled:
            return
        async with self._cache_lock:
            self._search_cache[cache_key] = {
                "results": results,
                "timestamp": time.time(),
                "used_qdrant_hybrid": _used_qdrant_hybrid_ctx.get(),
            }
            # Enforce the max-size bound now rather than on the next read.
            self._cleanup_expired_cache()
            self.logger.debug(
                "Cached search results",
                query=query[:50],
                results_count=len(results),
                cache_size=len(self._search_cache),
            )

    async def _run_filter_only_search(
        self,
        parsed_query,
        project_ids: list[str] | None,
        limit: int,
    ) -> list[FilterResult]:
        """Scroll-based exact-match path used when the query contains only field filters.

        Every match gets a constant score of 1.0; ``min_score`` is not applied.
        """
        self.logger.debug("Using filter-only search for exact field matching")
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )
        scroll_results = await self.qdrant_client.scroll(
            collection_name=self.collection_name,
            limit=limit,
            scroll_filter=query_filter,
            with_payload=True,
            with_vectors=False,
        )
        # scroll() returns a (points, next_page_offset) pair; only points are used.
        return [FilterResult(1.0, point.payload) for point in scroll_results[0]]

    async def _run_vector_search(
        self,
        parsed_query,
        project_ids: list[str] | None,
        original_query: str,
        limit: int,
    ) -> list:
        """Dispatch to either the Qdrant hybrid query or the dense-only query."""
        # Embed only the free-text part when field filters were extracted.
        search_query = parsed_query.text_query or original_query
        query_embedding = await self.get_embedding(search_query)
        search_params = models.SearchParams(
            hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
        )
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )
        # Routing is fully determined by the probed collection schema and the
        # runtime config — no silent fallback. If the collection supports
        # hybrid retrieval, we use hybrid; otherwise dense. Hybrid query
        # failures propagate so operators see them instead of degraded results.
        caps = await self._get_collection_capabilities()
        if self._hybrid_query_active(caps):
            return await self._run_qdrant_hybrid_query(
                query_embedding,
                search_query,
                query_filter,
                search_params,
                caps,
                limit,
            )
        return await self._run_dense_query(
            query_embedding, query_filter, search_params, caps, limit
        )

    async def _run_qdrant_hybrid_query(
        self,
        query_embedding: list[float],
        search_query: str,
        query_filter,
        search_params,
        caps: CollectionVectorCapabilities,
        limit: int,
    ) -> list:
        """Issue a server-side RRF fusion query over the dense + sparse prefetches.

        Sets the task-local hybrid flag to True only after the query succeeds.

        Raises:
            ValueError: If the sparse encoder yields an empty vector.
        """
        sparse_query = self._encode_sparse_query(search_query)
        if sparse_query is None:
            raise ValueError("Sparse query generation returned empty vector")

        # Over-fetch each branch so fusion has candidates beyond the final limit.
        prefetch_limit = limit * 3
        query_response = await self.qdrant_client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=query_embedding,
                    using=self._dense_using(caps),
                    filter=query_filter,
                    params=search_params,
                    limit=prefetch_limit,
                ),
                models.Prefetch(
                    query=sparse_query,
                    using=self.sparse_runtime.sparse_vector_name,
                    filter=query_filter,
                    limit=prefetch_limit,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=limit,
            with_payload=True,
        )
        _used_qdrant_hybrid_ctx.set(True)
        return query_response.points

    async def _run_dense_query(
        self,
        query_embedding: list[float],
        query_filter,
        search_params,
        caps: CollectionVectorCapabilities,
        limit: int,
    ) -> list:
        """Plain dense vector search, used when hybrid fusion is not active.

        ``min_score`` is applied as Qdrant's ``score_threshold``. The ``using``
        key is sent only for collections with a named dense vector.
        """
        query_kwargs: dict[str, Any] = {
            "collection_name": self.collection_name,
            "query": query_embedding,
            "limit": limit,
            "score_threshold": self.min_score,
            "search_params": search_params,
            "query_filter": query_filter,
            "with_payload": True,
        }
        using = self._dense_using(caps)
        if using:
            query_kwargs["using"] = using
        query_response = await self.qdrant_client.query_points(**query_kwargs)
        return query_response.points

    @staticmethod
    def _extract_hits(results: list) -> list[dict[str, Any]]:
        """Project Qdrant scored points (or FilterResults) to the dict shape consumed downstream.

        Missing payload fields default to empty strings/dicts; a missing
        ``source_type`` defaults to ``"unknown"``.
        """
        extracted: list[dict[str, Any]] = []
        for hit in results:
            # Payload may be absent or None; treat both as empty.
            payload = getattr(hit, "payload", None) or {}
            extracted.append(
                {
                    "score": hit.score,
                    "text": payload.get("content", ""),
                    "metadata": payload.get("metadata", {}),
                    "source_type": payload.get("source_type", "unknown"),
                    "title": payload.get("title", ""),
                    "url": payload.get("url", ""),
                    "document_id": payload.get("document_id", ""),
                    "source": payload.get("source", ""),
                    "created_at": payload.get("created_at", ""),
                    "updated_at": payload.get("updated_at", ""),
                    "contextual_content": payload.get("contextual_content", ""),
                }
            )
        return extracted

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results (hit/miss counters are kept)."""
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            A Qdrant Filter matching any of ``project_ids`` on ``project_id``,
            or None when no IDs are given.
        """
        if project_ids:
            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exposes project filter building via a public API. It wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)