Coverage for src/qdrant_loader_mcp_server/search/components/vector_search_service.py: 94% (115 statements)
1"""Vector search service for hybrid search."""
3import hashlib
4import time
5from asyncio import Lock
6from dataclasses import dataclass
7from typing import Any
9from qdrant_client import AsyncQdrantClient
10from qdrant_client.http import models
12from ...utils.logging import LoggingConfig
13from .field_query_parser import FieldQueryParser
@dataclass
class FilterResult:
    """Minimal stand-in for a Qdrant scored point, produced by filter-only searches."""

    score: float
    payload: dict
class VectorSearchService:
    """Handles vector search operations using Qdrant."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum score threshold
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: HNSW ``ef`` search parameter passed to Qdrant
            use_exact_search: Whether to force exact (non-approximate) search
            embeddings_provider: Optional embeddings provider (preferred when set)
            openai_client: Optional OpenAI client used as a fallback
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search
    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Create a deterministic string from search parameters
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()
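    # Example (a sketch): _generate_cache_key("auth setup", 5, ["proj-b", "proj-a"])
    # hashes the string "auth setup|5|proj-a,proj-b|0.3|<collection_name>" (with the
    # default min_score of 0.3), so identical queries share a cache entry and the
    # order of project IDs does not matter.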
    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from the cache and enforce the size limit."""
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries (simple FIFO eviction)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]
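    # Eviction sketch: with the defaults (cache_ttl=300, cache_max_size=500), an
    # entry written at t=0 expires on the first cleanup after t=300s; if more than
    # 500 live entries remain, the oldest timestamps are removed first (FIFO).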
    async def get_embedding(self, text: str) -> list[float]:
        """Get an embedding for text, preferring the embeddings provider and
        falling back to the OpenAI client.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            Exception: If embedding generation fails
            RuntimeError: If neither a provider nor an OpenAI client is configured
        """
        # Prefer the provider when available
        if self.embeddings_provider is not None:
            try:
                # Accept either a provider (with .embeddings()) or a direct embeddings client
                client = (
                    self.embeddings_provider.embeddings()
                    if hasattr(self.embeddings_provider, "embeddings")
                    else self.embeddings_provider
                )
                vectors = await client.embed([text])
                return vectors[0]
            except Exception as e:
                self.logger.error("Provider embeddings failed", error=str(e))
                raise

        # Fall back to the OpenAI client when no provider is configured
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model="text-embedding-3-small",
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                # Do not fall back silently; propagate the error as tests expect
                self.logger.error("Failed to get embedding", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")
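    # Usage sketch (hypothetical; assumes the service was constructed with an
    # embeddings provider or an OpenAI client, inside an async context):
    #     vector = await service.get_embedding("how do I configure auth?")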
    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing Qdrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )
        # Parse query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            results = []
            for point in scroll_results[0]:  # scroll_results is (points, next_page_offset)
                results.append(FilterResult(1.0, point.payload))
        else:
            # Hybrid search (vector search + field filters)
            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            results = await self.qdrant_client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # Explicitly request payload data
            )
        extracted_results = []
        for hit in results:
            extracted = {
                "score": hit.score,
                "text": hit.payload.get("content", "") if hit.payload else "",
                "metadata": hit.payload.get("metadata", {}) if hit.payload else {},
                "source_type": (
                    hit.payload.get("source_type", "unknown")
                    if hit.payload
                    else "unknown"
                ),
                # Extract fields directly from Qdrant payload
                "title": hit.payload.get("title", "") if hit.payload else "",
                "url": hit.payload.get("url", "") if hit.payload else "",
                "document_id": (
                    hit.payload.get("document_id", "") if hit.payload else ""
                ),
                "source": hit.payload.get("source", "") if hit.payload else "",
                "created_at": hit.payload.get("created_at", "") if hit.payload else "",
                "updated_at": hit.payload.get("updated_at", "") if hit.payload else "",
            }

            extracted_results.append(extracted)

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

            self.logger.debug(
                "Cached search results",
                query=query[:50],
                results_count=len(extracted_results),
                cache_size=len(self._search_cache),
            )

        return extracted_results
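    # Usage sketch (illustrative field query; `service` is a configured instance):
    #     results = await service.vector_search("source_type:git login bug", limit=10)
    #     # "source_type:git" is parsed into a Qdrant filter; the remaining text
    #     # is embedded and used for the vector search.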
    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }
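    # Example return value (values illustrative):
    #     {"cache_enabled": True, "cache_hits": 42, "cache_misses": 8,
    #      "hit_rate_percent": 84.0, "cache_size": 37, "cache_max_size": 500,
    #      "cache_ttl_seconds": 300}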
    def clear_cache(self) -> None:
        """Clear all cached search results."""
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")
    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Legacy method kept for backward compatibility - prefer
        FieldQueryParser.create_qdrant_filter() for field queries.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None
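    # For example (a sketch): _build_filter(["proj-a"]) returns
    #     models.Filter(
    #         must=[models.FieldCondition(key="project_id", match=models.MatchAny(any=["proj-a"]))]
    #     )
    # while _build_filter(None) returns None.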
    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps
        the legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)
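
# Usage sketch (illustrative only; `provider` is a hypothetical object exposing
# an async `embed(texts) -> list[list[float]]` method and is not part of this module):
#
#     import asyncio
#     from qdrant_client import AsyncQdrantClient
#
#     async def main() -> None:
#         client = AsyncQdrantClient(url="http://localhost:6333")
#         service = VectorSearchService(
#             qdrant_client=client,
#             collection_name="documents",
#             embeddings_provider=provider,
#         )
#         results = await service.vector_search("how to configure auth", limit=5)
#         print(service.get_cache_stats())
#
#     asyncio.run(main())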