Coverage for src / qdrant_loader_mcp_server / search / components / vector_search_service.py: 95%
183 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:38 +0000
1"""Vector search service for hybrid search."""
3from __future__ import annotations
5import asyncio
6import hashlib
7import time
8from asyncio import Lock
9from contextvars import ContextVar
10from dataclasses import dataclass
11from typing import TYPE_CHECKING, Any
13if TYPE_CHECKING:
14 from qdrant_client import AsyncQdrantClient
15 from qdrant_client.http import models as qdrant_models
17from qdrant_client.http import models
18from qdrant_loader_core.config import (
19 CollectionVectorCapabilities,
20 parse_collection_capabilities,
21)
22from qdrant_loader_core.sparse import get_sparse_encoder
24from ...utils.logging import LoggingConfig
25from ..sparse_config import load_sparse_runtime_config
26from .field_query_parser import FieldQueryParser
# Per-task marker: vector_search flips it on when the current query went
# through Qdrant fusion. Keeping it in a ContextVar means concurrent searches
# that share one VectorSearchService instance each see their own value.
_used_qdrant_hybrid_ctx: ContextVar[bool] = ContextVar(
    "vector_search_used_qdrant_hybrid",
    default=False,
)
@dataclass
class FilterResult:
    """Score/payload pair produced by the filter-only (scroll) search path.

    It exposes the same ``score`` and ``payload`` attributes that the hit
    extraction step reads from Qdrant query results, so both paths can share
    that step.
    """

    # The filter-only path always passes 1.0 here (exact field match).
    score: float
    # Payload taken from the scrolled point.
    payload: dict
class VectorSearchService:
    """Handles vector search operations using Qdrant.

    The service embeds the query text, routes the request to a filter-only
    scroll, a dense query, or a server-side dense+sparse fusion query, and
    projects the hits to plain dicts. Results are kept in an in-memory,
    TTL-bounded cache keyed by the search parameters.
    """

    # Retry policy applied to the embeddings provider in ``get_embedding``.
    _EMBED_MAX_ATTEMPTS = 3
    _EMBED_RETRY_DELAY_S = 0.5

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum score threshold (passed as ``score_threshold``
                to the dense query)
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: ``hnsw_ef`` value forwarded to Qdrant ``SearchParams``
            use_exact_search: Forwarded as ``exact`` to Qdrant ``SearchParams``
            embeddings_provider: Preferred embeddings source; either an object
                exposing ``embeddings()`` or a client exposing ``embed()``
            openai_client: OpenAI client instance, used as a fallback when the
                provider is absent or keeps failing
            embedding_model: Model name sent to the OpenAI embeddings endpoint
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        self.sparse_runtime = load_sparse_runtime_config()
        # Filled lazily by _get_collection_capabilities(); None means "not probed yet".
        self._collection_capabilities: CollectionVectorCapabilities | None = None
        self._capabilities_lock: Lock = Lock()

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        The parameters are serialised as the ``repr`` of a tuple rather than a
        delimiter-joined string, so distinct inputs cannot collapse onto the
        same text (e.g. ``project_ids=["none"]`` vs. no filter, or
        ``["a,b"]`` vs. ``["a", "b"]``).

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters; order is irrelevant and
                an empty list is treated like ``None``

        Returns:
            SHA256 hex digest of the search parameters
        """
        project_key = tuple(sorted(project_ids)) if project_ids else None
        cache_input = repr(
            (query, limit, project_key, self.min_score, self.collection_name)
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache, then trim it to ``cache_max_size``."""
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries first (ordered by insertion timestamp)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def _get_collection_capabilities(self) -> CollectionVectorCapabilities:
        """Probe the collection for named-dense and sparse vector support.

        Transient ``get_collection`` failures are logged and returned as
        empty capabilities for the current call only — they are never cached,
        so the next call retries.
        """
        if self._collection_capabilities is not None:
            return self._collection_capabilities

        async with self._capabilities_lock:
            # Re-check under the lock: another task may have finished the
            # probe while this one was waiting.
            if self._collection_capabilities is not None:
                return self._collection_capabilities
            try:
                info = await self.qdrant_client.get_collection(
                    collection_name=self.collection_name
                )
            except Exception as e:
                # Don't cache: a transient outage would otherwise pin every
                # subsequent query to the dense fallback path for this
                # service's lifetime.
                self.logger.warning(
                    "Failed to inspect Qdrant collection schema; assuming dense-only",
                    collection=self.collection_name,
                    error=str(e),
                )
                return CollectionVectorCapabilities()
            self._collection_capabilities = parse_collection_capabilities(
                info, self.sparse_runtime
            )
            return self._collection_capabilities

    def _dense_using(self, caps: CollectionVectorCapabilities) -> str | None:
        """Return the named-dense vector key, or None for legacy unnamed collections."""
        return self.sparse_runtime.dense_vector_name if caps.has_named_dense else None

    def _hybrid_query_active(self, caps: CollectionVectorCapabilities) -> bool:
        """Combine collection schema with runtime config to decide on server-side fusion."""
        return (
            caps.hybrid_ready
            and self.sparse_runtime.enabled
            and self.sparse_runtime.use_qdrant_hybrid
        )

    async def supports_qdrant_hybrid(self) -> bool:
        """Return True if the collection is configured for Qdrant dense+sparse fusion.

        Used by HybridPipeline to decide whether to skip the separate keyword
        search; calling this lets the pipeline preserve parallelism in the
        dense-only path.
        """
        caps = await self._get_collection_capabilities()
        return self._hybrid_query_active(caps)

    def _encode_sparse_query(self, text: str) -> models.SparseVector | None:
        """Encode ``text`` with the configured sparse encoder.

        Returns:
            A Qdrant ``SparseVector``, or ``None`` when the encoder reports an
            empty encoding for the text.
        """
        sparse = get_sparse_encoder(self.sparse_runtime.model).encode_query(text)
        if sparse.is_empty():
            return None
        return models.SparseVector(indices=sparse.indices, values=sparse.values)

    def used_qdrant_hybrid_last_query(self) -> bool:
        """Return whether the most recent vector_search in this task used Qdrant fusion.

        Backed by a ContextVar so concurrent searches across tasks (e.g. multiple
        MCP requests sharing this service) don't clobber each other.
        """
        return _used_qdrant_hybrid_ctx.get()

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using the provider when available, else OpenAI.

        The embeddings provider is tried up to ``_EMBED_MAX_ATTEMPTS`` times
        with a short pause between attempts. If it is absent or every attempt
        fails, the OpenAI client is used as a fallback.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            Exception: Whatever the OpenAI client raised, if the fallback fails
            RuntimeError: If no usable embedding source is left; chained to the
                last provider error when a provider was configured but failed
        """
        provider_error: Exception | None = None

        # Prefer provider when available
        if self.embeddings_provider is not None:
            # Accept either a provider (with .embeddings()) or a direct embeddings client
            client = (
                self.embeddings_provider.embeddings()
                if hasattr(self.embeddings_provider, "embeddings")
                else self.embeddings_provider
            )
            for attempt in range(1, self._EMBED_MAX_ATTEMPTS + 1):
                try:
                    vectors = await client.embed([text])
                    return vectors[0]
                except Exception as e:
                    provider_error = e
                    self.logger.warning(
                        "Provider embedding failed, retrying...", error=str(e)
                    )
                    # No point in pausing after the final attempt.
                    if attempt < self._EMBED_MAX_ATTEMPTS:
                        await asyncio.sleep(self._EMBED_RETRY_DELAY_S)

        # Fallback to OpenAI (to keep backward compatibility & pass tests)
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model=self.embedding_model,
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                self.logger.error("OpenAI fallback failed", error=str(e))
                raise

        if provider_error is not None:
            # A provider WAS configured; report its failure instead of
            # claiming that nothing is configured.
            raise RuntimeError(
                f"Embeddings provider failed after {self._EMBED_MAX_ATTEMPTS} "
                "attempts and no OpenAI client is configured"
            ) from provider_error

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Reset before the cache lookup so used_qdrant_hybrid_last_query()
        # reflects this call rather than the previous task-local value, even
        # when we return a cached result.
        # NOTE(review): on a cache hit the flag therefore stays False even if
        # the cached results were originally produced by the fusion path —
        # confirm that callers of used_qdrant_hybrid_last_query() expect this.
        _used_qdrant_hybrid_ctx.set(False)
        cache_key = self._generate_cache_key(query, limit, project_ids)
        cached = await self._cache_get_if_valid(cache_key, query)
        if cached is not None:
            return cached

        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, "
            f"text: '{parsed_query.text_query}'"
        )

        if self.field_parser.should_use_filter_only(parsed_query):
            results = await self._run_filter_only_search(
                parsed_query, project_ids, limit
            )
        else:
            results = await self._run_vector_search(
                parsed_query, project_ids, query, limit
            )

        extracted = self._extract_hits(results)
        await self._cache_put(cache_key, extracted, query)
        return extracted

    async def _cache_get_if_valid(
        self, cache_key: str, query: str
    ) -> list[dict[str, Any]] | None:
        """Return cached results if present and not expired; otherwise increment the miss counter.

        The miss counter is also incremented when caching is disabled.
        """
        if self.cache_enabled:
            async with self._cache_lock:
                self._cleanup_expired_cache()
                cached_entry = self._search_cache.get(cache_key)
                if (
                    cached_entry is not None
                    and time.time() - cached_entry["timestamp"] <= self.cache_ttl
                ):
                    self._cache_hits += 1
                    self.logger.debug(
                        "Search cache hit",
                        query=query[:50],
                        cache_hits=self._cache_hits,
                        cache_misses=self._cache_misses,
                        hit_rate=(
                            f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%"
                        ),
                    )
                    return cached_entry["results"]

        self._cache_misses += 1
        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )
        return None

    async def _cache_put(
        self, cache_key: str, results: list[dict[str, Any]], query: str
    ) -> None:
        """Store ``results`` under ``cache_key`` when caching is enabled.

        Size and TTL limits are not enforced here; that happens in
        ``_cleanup_expired_cache`` on the next cache lookup.
        """
        if not self.cache_enabled:
            return
        async with self._cache_lock:
            self._search_cache[cache_key] = {
                "results": results,
                "timestamp": time.time(),
            }
            self.logger.debug(
                "Cached search results",
                query=query[:50],
                results_count=len(results),
                cache_size=len(self._search_cache),
            )

    async def _run_filter_only_search(
        self,
        parsed_query,
        project_ids: list[str] | None,
        limit: int,
    ) -> list[FilterResult]:
        """Scroll-based exact-match path used when the query contains only field filters.

        Every returned point gets a fixed score of 1.0 because no similarity
        ranking is involved.
        """
        self.logger.debug("Using filter-only search for exact field matching")
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )
        scroll_results = await self.qdrant_client.scroll(
            collection_name=self.collection_name,
            limit=limit,
            scroll_filter=query_filter,
            with_payload=True,
            with_vectors=False,
        )
        # scroll() returns a pair; the first element holds the points.
        return [FilterResult(1.0, point.payload) for point in scroll_results[0]]

    async def _run_vector_search(
        self,
        parsed_query,
        project_ids: list[str] | None,
        original_query: str,
        limit: int,
    ) -> list:
        """Dispatch to either the Qdrant hybrid query or the dense-only query."""
        # Embed the free-text part of the query; fall back to the raw query
        # when the parser left no text.
        search_query = parsed_query.text_query or original_query
        query_embedding = await self.get_embedding(search_query)
        search_params = models.SearchParams(
            hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
        )
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )
        # Routing is fully determined by the probed collection schema and the
        # runtime config — no silent fallback. If the collection supports
        # hybrid retrieval, we use hybrid; otherwise dense. Hybrid query
        # failures propagate so operators see them instead of degraded results.
        caps = await self._get_collection_capabilities()
        if self._hybrid_query_active(caps):
            return await self._run_qdrant_hybrid_query(
                query_embedding,
                search_query,
                query_filter,
                search_params,
                caps,
                limit,
            )
        return await self._run_dense_query(
            query_embedding, query_filter, search_params, caps, limit
        )

    async def _run_qdrant_hybrid_query(
        self,
        query_embedding: list[float],
        search_query: str,
        query_filter,
        search_params,
        caps: CollectionVectorCapabilities,
        limit: int,
    ) -> list:
        """Issue a server-side RRF fusion query over the dense + sparse prefetches.

        Raises:
            ValueError: If the sparse encoder yields an empty vector for
                ``search_query``.
        """
        sparse_query = self._encode_sparse_query(search_query)
        if sparse_query is None:
            raise ValueError("Sparse query generation returned empty vector")

        # Over-fetch each branch so fusion has more candidates than the final limit.
        prefetch_limit = limit * 3
        query_response = await self.qdrant_client.query_points(
            collection_name=self.collection_name,
            prefetch=[
                models.Prefetch(
                    query=query_embedding,
                    using=self._dense_using(caps),
                    filter=query_filter,
                    params=search_params,
                    limit=prefetch_limit,
                ),
                models.Prefetch(
                    query=sparse_query,
                    using=self.sparse_runtime.sparse_vector_name,
                    filter=query_filter,
                    limit=prefetch_limit,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=limit,
            with_payload=True,
        )
        # Only mark the task as "hybrid" once the query actually succeeded.
        _used_qdrant_hybrid_ctx.set(True)
        return query_response.points

    async def _run_dense_query(
        self,
        query_embedding: list[float],
        query_filter,
        search_params,
        caps: CollectionVectorCapabilities,
        limit: int,
    ) -> list:
        """Plain dense vector search — used when the hybrid path is not active.

        ``min_score`` is applied as the Qdrant ``score_threshold``. The
        ``using`` argument is only sent for collections with a named dense
        vector.
        """
        query_kwargs: dict[str, Any] = {
            "collection_name": self.collection_name,
            "query": query_embedding,
            "limit": limit,
            "score_threshold": self.min_score,
            "search_params": search_params,
            "query_filter": query_filter,
            "with_payload": True,
        }
        using = self._dense_using(caps)
        if using:
            query_kwargs["using"] = using
        query_response = await self.qdrant_client.query_points(**query_kwargs)
        return query_response.points

    @staticmethod
    def _extract_hits(results: list) -> list[dict[str, Any]]:
        """Project Qdrant scored points to the dict shape consumed downstream.

        Hits with a missing or ``None`` payload are kept and filled with
        defaults.
        """
        extracted: list[dict[str, Any]] = []
        for hit in results:
            payload = getattr(hit, "payload", None) or {}
            extracted.append(
                {
                    "score": hit.score,
                    "text": payload.get("content", ""),
                    "metadata": payload.get("metadata", {}),
                    "source_type": payload.get("source_type", "unknown"),
                    "title": payload.get("title", ""),
                    "url": payload.get("url", ""),
                    "document_id": payload.get("document_id", ""),
                    "source": payload.get("source", ""),
                    "created_at": payload.get("created_at", ""),
                    "updated_at": payload.get("updated_at", ""),
                    "contextual_content": payload.get("contextual_content", ""),
                }
            )
        return extracted

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results (hit/miss counters are left untouched)."""
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)