Coverage for src/qdrant_loader_mcp_server/search/components/vector_search_service.py: 94%
119 statements
coverage.py v7.13.5, created at 2026-04-10 09:41 +0000
"""Vector search service for hybrid search."""

from __future__ import annotations

import asyncio
import hashlib
import time
from asyncio import Lock
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from qdrant_client import AsyncQdrantClient
    from qdrant_client.http import models as qdrant_models

from ...utils.logging import LoggingConfig
from .field_query_parser import FieldQueryParser


@dataclass
class FilterResult:
    score: float
    payload: dict
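
# FilterResult mirrors the ``.score`` / ``.payload`` shape of a scored Qdrant
# point, so filter-only results returned by ``scroll`` can flow through the same
# extraction loop as vector hits; the fixed score of 1.0 simply marks an exact
# field match.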


class VectorSearchService:
    """Handles vector search operations using Qdrant."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum score threshold for vector search results
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: HNSW ``ef`` search parameter (accuracy/speed trade-off)
            use_exact_search: Whether to force exact (non-approximate) search
            embeddings_provider: Optional embeddings provider, preferred for embeddings
            openai_client: Optional OpenAI client, used as a fallback
            embedding_model: Embedding model name used with the OpenAI fallback
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search
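
    # Illustrative construction (client setup and values are assumptions, not
    # part of this module):
    #
    #     from openai import AsyncOpenAI
    #     from qdrant_client import AsyncQdrantClient
    #
    #     client = AsyncQdrantClient(url="http://localhost:6333")
    #     service = VectorSearchService(
    #         client,
    #         collection_name="documents",
    #         openai_client=AsyncOpenAI(),
    #     )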

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Create a deterministic string from search parameters
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()
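
    # _generate_cache_key example (illustrative values): for query="auth flow",
    # limit=5, project_ids=["p2", "p1"], min_score=0.3 and collection "docs",
    # the hashed input string is "auth flow|5|p1,p2|0.3|docs", so identical
    # searches map to the same key regardless of project ID ordering.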

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache."""
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries (simple FIFO eviction)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]
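
    # _cleanup_expired_cache first drops entries older than cache_ttl, then, if
    # the cache is still over cache_max_size, evicts the oldest remaining
    # entries (by insertion timestamp, FIFO) until the size limit is met.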

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text, preferring the embeddings provider and falling back to OpenAI.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            Exception: If embedding generation fails
        """
        # Prefer provider when available
        if self.embeddings_provider is not None:
            # Accept either a provider (with .embeddings()) or a direct embeddings client
            client = (
                self.embeddings_provider.embeddings()
                if hasattr(self.embeddings_provider, "embeddings")
                else self.embeddings_provider
            )
            for _ in range(3):
                try:
                    vectors = await client.embed([text])
                    return vectors[0]
                except Exception as e:
                    self.logger.warning(
                        "Provider embedding failed, retrying...", error=str(e)
                    )
                    await asyncio.sleep(0.5)

        # Fallback to OpenAI (to keep backward compatibility & pass tests)
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model=self.embedding_model,
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                self.logger.error("OpenAI fallback failed", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")
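
    # The object passed as ``embeddings_provider`` is duck-typed; a minimal
    # compatible shape (assumed here for illustration) looks like:
    #
    #     class MyEmbeddings:
    #         async def embed(self, texts: list[str]) -> list[list[float]]:
    #             ...
    #
    #     class MyProvider:
    #         def embeddings(self) -> MyEmbeddings:
    #             return MyEmbeddings()
    #
    # Either object can be supplied directly: the service calls ``.embeddings()``
    # when present and otherwise uses the object itself as the embeddings client.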

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # ✅ Parse query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            results = []
            for point in scroll_results[0]:  # scroll_results is (points, next_page_offset)
                results.append(FilterResult(1.0, point.payload))
        else:
            # Hybrid search (vector search + field filters)
            from qdrant_client.http import models

            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # Use query_points API (qdrant-client 1.10+)
            query_response = await self.qdrant_client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # 🔧 CRITICAL: Explicitly request payload data
            )
            results = query_response.points

        extracted_results = []
        for hit in results:
            extracted = {
                "score": hit.score,
                "text": hit.payload.get("content", "") if hit.payload else "",
                "metadata": hit.payload.get("metadata", {}) if hit.payload else {},
                "source_type": (
                    hit.payload.get("source_type", "unknown")
                    if hit.payload
                    else "unknown"
                ),
                # Extract fields directly from Qdrant payload
                "title": hit.payload.get("title", "") if hit.payload else "",
                "url": hit.payload.get("url", "") if hit.payload else "",
                "document_id": (
                    hit.payload.get("document_id", "") if hit.payload else ""
                ),
                "source": hit.payload.get("source", "") if hit.payload else "",
                "created_at": hit.payload.get("created_at", "") if hit.payload else "",
                "updated_at": hit.payload.get("updated_at", "") if hit.payload else "",
                "contextual_content": (
                    hit.payload.get("contextual_content", "") if hit.payload else ""
                ),
            }

            extracted_results.append(extracted)

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results
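
    # Illustrative usage (query text and project IDs are example values),
    # assuming an initialized service instance:
    #
    #     results = await service.vector_search(
    #         "document_type:pdf architecture overview",
    #         limit=5,
    #         project_ids=["project-a"],
    #     )
    #     for r in results:
    #         print(r["score"], r["title"], r["source_type"])
    #
    # field:value tokens (e.g. ``document_type:pdf``) are split out by
    # FieldQueryParser and become a Qdrant payload filter; the remaining free
    # text is embedded and used for the vector search.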

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }
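
    # A typical return value looks like (counts are illustrative):
    #
    #     {
    #         "cache_enabled": True,
    #         "cache_hits": 42,
    #         "cache_misses": 8,
    #         "hit_rate_percent": 84.0,
    #         "cache_size": 37,
    #         "cache_max_size": 500,
    #         "cache_ttl_seconds": 300,
    #     }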

    def clear_cache(self) -> None:
        """Clear all cached search results."""
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            from qdrant_client.http import models

            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    # Note: _build_filter method added back for backward compatibility - prefer FieldQueryParser.create_qdrant_filter()

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)
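
    # For example, ``service.build_filter(["p1", "p2"])`` returns a filter
    # equivalent to:
    #
    #     models.Filter(
    #         must=[
    #             models.FieldCondition(
    #                 key="project_id", match=models.MatchAny(any=["p1", "p2"])
    #             )
    #         ]
    #     )
    #
    # i.e. points are kept only when their ``project_id`` payload field matches
    # one of the given project IDs.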