Coverage for src/qdrant_loader_mcp_server/search/components/vector_search_service.py: 94% (116 statements)
1"""Vector search service for hybrid search."""
3from __future__ import annotations
5import hashlib
6import time
7from asyncio import Lock
8from dataclasses import dataclass
9from typing import TYPE_CHECKING, Any
11if TYPE_CHECKING:
12 from qdrant_client import AsyncQdrantClient
13 from qdrant_client.http import models as qdrant_models
15from ...utils.logging import LoggingConfig
16from .field_query_parser import FieldQueryParser
19@dataclass
20class FilterResult:
21 score: float
22 payload: dict
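
# FilterResult mirrors the ``score``/``payload`` attributes of Qdrant's scored
# points, so filter-only results from ``scroll`` can flow through the same
# payload-extraction loop in ``vector_search`` as regular vector hits.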


class VectorSearchService:
    """Handles vector search operations using Qdrant."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
    ):
42 """Initialize the vector search service.
44 Args:
45 qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
46 openai_client: OpenAI client instance
47 collection_name: Name of the Qdrant collection
48 min_score: Minimum score threshold
49 cache_enabled: Whether to enable search result caching
50 cache_ttl: Cache time-to-live in seconds
51 cache_max_size: Maximum number of cached results
52 """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of the search parameters, used as the cache key
        """
        # Create a deterministic string from search parameters
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()
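
    # For example (hypothetical values), these two calls produce the same key,
    # because project IDs are sorted before hashing:
    #
    #     self._generate_cache_key("setup guide", 5, ["proj-b", "proj-a"])
    #     self._generate_cache_key("setup guide", 5, ["proj-a", "proj-b"])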

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache."""
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce the max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove the oldest entries (simple FIFO eviction)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def get_embedding(self, text: str) -> list[float]:
        """Get an embedding for the text, preferring the embeddings provider
        and falling back to the OpenAI client when no provider is configured.

        Args:
            text: Text to get an embedding for

        Returns:
            List of embedding values

        Raises:
            Exception: If embedding generation fails
        """
        # Prefer the provider when available
        if self.embeddings_provider is not None:
            try:
                # Accept either a provider (with .embeddings()) or a direct embeddings client
                client = (
                    self.embeddings_provider.embeddings()
                    if hasattr(self.embeddings_provider, "embeddings")
                    else self.embeddings_provider
                )
                vectors = await client.embed([text])
                return vectors[0]
            except Exception as e:
                self.logger.error("Provider embeddings failed", error=str(e))
                raise

        # Fall back to the OpenAI client when no provider is configured
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model="text-embedding-3-small",
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                # Do not fall back silently; propagate the error as tests expect
                self.logger.error("Failed to get embedding", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")
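
    # A sketch of the provider contract assumed by ``get_embedding`` (not a
    # definitive interface): either an object whose ``embeddings()`` method
    # returns an embeddings client, or the embeddings client itself, where the
    # client implements:
    #
    #     class EmbeddingsClient(Protocol):
    #         async def embed(self, texts: list[str]) -> list[list[float]]: ...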

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate the cache key and check the cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check the cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify the cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform the actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing Qdrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # Parse the query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine the search strategy based on the parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with a filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            results = []
            # scroll() returns a (points, next_page_offset) tuple
            for point in scroll_results[0]:
                results.append(FilterResult(1.0, point.payload))
        else:
            # Hybrid search (vector search + field filters)
            from qdrant_client.http import models

            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # Use the query_points API (qdrant-client 1.10+)
            query_response = await self.qdrant_client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # Explicitly request payload data
            )
            results = query_response.points

        extracted_results = []
        for hit in results:
            extracted = {
                "score": hit.score,
                "text": hit.payload.get("content", "") if hit.payload else "",
                "metadata": hit.payload.get("metadata", {}) if hit.payload else {},
                "source_type": (
                    hit.payload.get("source_type", "unknown")
                    if hit.payload
                    else "unknown"
                ),
                # Extract fields directly from the Qdrant payload
                "title": hit.payload.get("title", "") if hit.payload else "",
                "url": hit.payload.get("url", "") if hit.payload else "",
                "document_id": (
                    hit.payload.get("document_id", "") if hit.payload else ""
                ),
                "source": hit.payload.get("source", "") if hit.payload else "",
                "created_at": hit.payload.get("created_at", "") if hit.payload else "",
                "updated_at": hit.payload.get("updated_at", "") if hit.payload else "",
            }

            extracted_results.append(extracted)

        # Store results in the cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results
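
    # Example queries (illustrative; the exact field names supported are
    # defined by FieldQueryParser):
    #
    #     await service.vector_search("how do I configure caching?", limit=5)
    #     # -> embedding search, scores thresholded at ``min_score``
    #
    #     await service.vector_search("document_id:abc123", limit=1)
    #     # -> may short-circuit to a filter-only scroll (score fixed at 1.0)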

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results."""
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method kept for backward compatibility - prefer
        FieldQueryParser.create_qdrant_filter() instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            from qdrant_client.http import models

            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exposes project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)
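

# Minimal usage sketch (assumes a running Qdrant instance and an embeddings
# provider exposing ``async def embed(texts: list[str]) -> list[list[float]]``;
# names such as ``my_provider`` are hypothetical):
#
#     import asyncio
#     from qdrant_client import AsyncQdrantClient
#
#     async def main() -> None:
#         service = VectorSearchService(
#             qdrant_client=AsyncQdrantClient(url="http://localhost:6333"),
#             collection_name="documents",
#             embeddings_provider=my_provider,
#         )
#         results = await service.vector_search("release notes", limit=5)
#         for result in results:
#             print(result["score"], result["title"], result["url"])
#         print(service.get_cache_stats())
#
#     asyncio.run(main())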