Coverage for src / qdrant_loader_mcp_server / search / components / vector_search_service.py: 94%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 04:51 +0000

1"""Vector search service for hybrid search.""" 

2 

3from __future__ import annotations 

4 

5import hashlib 

6import time 

7from asyncio import Lock 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING, Any 

10 

11if TYPE_CHECKING: 

12 from qdrant_client import AsyncQdrantClient 

13 from qdrant_client.http import models as qdrant_models 

14 

15from ...utils.logging import LoggingConfig 

16from .field_query_parser import FieldQueryParser 

17 

18 

@dataclass
class FilterResult:
    """Minimal scored-result container for filter-only searches.

    Mirrors the ``score``/``payload`` attributes of a Qdrant scored point so
    that results returned by ``scroll`` (exact field matching, no vector
    similarity) can flow through the same extraction code as real
    vector-search hits.
    """

    # Always 1.0 in practice: filter-only matches carry no similarity score.
    score: float
    # Raw Qdrant point payload (keys read downstream include "content",
    # "metadata", "source_type", "title", "url", ...).
    payload: dict[str, Any]

23 

24 

class VectorSearchService:
    """Handles vector search operations using Qdrant.

    Supports two query styles:

    * filter-only searches — exact ``field:value`` matching via ``scroll``;
    * hybrid searches — vector similarity combined with field/project filters;

    with optional TTL- and size-bounded caching of extracted results.
    """

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum similarity score threshold for vector hits
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: HNSW ``ef`` search-time parameter forwarded to Qdrant
            use_exact_search: Force exact (non-approximate) search in Qdrant
            embeddings_provider: Preferred embeddings source; either an object
                exposing ``embeddings()`` or a direct client with an async
                ``embed(texts)`` method
            openai_client: OpenAI client used as an embeddings fallback when
                no provider is configured
            embedding_model: Model name used by the OpenAI fallback path
                (default matches the previously hard-coded value)
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.embedding_model = embedding_model
        self.collection_name = collection_name
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        # Async lock guarding cache reads/cleanup/writes inside vector_search.
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Create a deterministic string from search parameters; project IDs are
        # sorted so equivalent filters in different order hit the same entry.
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache, then enforce the size cap.

        Caller is expected to hold ``_cache_lock`` (vector_search does).
        """
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries (simple FIFO eviction by timestamp)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using the provider when available, else OpenAI.

        Resolution order: ``embeddings_provider`` first (either a provider with
        ``.embeddings()`` or a direct embeddings client), then ``openai_client``
        with ``self.embedding_model``.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            RuntimeError: If neither a provider nor an OpenAI client is configured
            Exception: Propagated unchanged if embedding generation fails
        """
        # Prefer provider when available
        if self.embeddings_provider is not None:
            try:
                # Accept either a provider (with .embeddings()) or a direct embeddings client
                client = (
                    self.embeddings_provider.embeddings()
                    if hasattr(self.embeddings_provider, "embeddings")
                    else self.embeddings_provider
                )
                vectors = await client.embed([text])
                return vectors[0]
            except Exception as e:
                self.logger.error("Provider embeddings failed", error=str(e))
                raise

        # Fallback to OpenAI client when provider is not configured
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model=self.embedding_model,
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                # Do not fall back silently; propagate error as tests expect
                self.logger.error("Failed to get embedding", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")

    @staticmethod
    def _format_hit(hit: Any) -> dict[str, Any]:
        """Flatten a Qdrant scored point (or FilterResult) into a result dict.

        A missing/empty payload yields the same defaults the caller expects:
        empty strings, empty metadata dict, and "unknown" source_type.
        """
        payload = hit.payload or {}
        return {
            "score": hit.score,
            "text": payload.get("content", ""),
            "metadata": payload.get("metadata", {}),
            "source_type": payload.get("source_type", "unknown"),
            # Extract fields directly from Qdrant payload
            "title": payload.get("title", ""),
            "url": payload.get("url", ""),
            "document_id": payload.get("document_id", ""),
            "source": payload.get("source", ""),
            "created_at": payload.get("created_at", ""),
            "updated_at": payload.get("updated_at", ""),
        }

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query (may contain ``field:value`` clauses)
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # Parse query for field-specific filters (field:value syntax)
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            # scroll_results is (points, next_page_offset); wrap points so they
            # expose .score/.payload like vector hits (score fixed at 1.0).
            results = [FilterResult(1.0, point.payload) for point in scroll_results[0]]
        else:
            # Hybrid search (vector search + field filters)
            from qdrant_client.http import models

            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # Use query_points API (qdrant-client 1.10+)
            query_response = await self.qdrant_client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # Explicitly request payload data
            )
            results = query_response.points

        # Flatten hits into plain dicts for downstream consumers.
        extracted_results = [self._format_hit(hit) for hit in results]

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results (hit/miss counters are kept).

        NOTE(review): this mutates ``_search_cache`` without acquiring
        ``_cache_lock`` (this method is synchronous and cannot await the
        lock); concurrent vector_search calls may repopulate entries during
        the clear — confirm this best-effort semantics is intended.
        """
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            from qdrant_client.http import models

            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    # Note: _build_filter method added back for backward compatibility - prefer FieldQueryParser.create_qdrant_filter()

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)