Coverage for src / qdrant_loader_mcp_server / search / components / vector_search_service.py: 94%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-10 09:41 +0000

1"""Vector search service for hybrid search.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import hashlib 

7import time 

8from asyncio import Lock 

9from dataclasses import dataclass 

10from typing import TYPE_CHECKING, Any 

11 

12if TYPE_CHECKING: 

13 from qdrant_client import AsyncQdrantClient 

14 from qdrant_client.http import models as qdrant_models 

15 

16from ...utils.logging import LoggingConfig 

17from .field_query_parser import FieldQueryParser 

18 

19 

@dataclass
class FilterResult:
    """Minimal stand-in for a Qdrant scored point, used by filter-only searches.

    `vector_search` builds these (with a fixed score of 1.0) from `scroll`
    results so that filter-only hits flow through the same `score`/`payload`
    extraction path as real vector-search hits.
    """

    # Relevance score; always 1.0 for filter-only matches (no vector distance).
    score: float
    # Raw Qdrant point payload (content, metadata, title, url, ...).
    payload: dict

24 

25 

class VectorSearchService:
    """Handles vector search operations using Qdrant.

    Supports three query paths:
      * filter-only search (exact ``field:value`` matching via ``scroll``),
      * hybrid search (vector similarity plus field/project filters),
      * embedding generation with a pluggable provider and OpenAI fallback.

    Results can optionally be cached with a TTL and FIFO size eviction.
    """

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
        embedding_model: str = "text-embedding-3-small",
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum similarity score threshold for vector hits
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: Qdrant HNSW ``ef`` search parameter (higher = more
                accurate, slower)
            use_exact_search: If True, request exact (non-approximate) search
            embeddings_provider: Optional embeddings provider; either exposes
                ``.embeddings()`` returning a client with ``.embed()``, or is
                such a client itself. Preferred over the OpenAI client.
            openai_client: OpenAI client instance used as embedding fallback
            embedding_model: Model name used with the OpenAI fallback
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Sort project IDs so equivalent filters produce the same key.
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache, then enforce the size cap.

        Eviction is simple FIFO by insertion timestamp (oldest entries are
        dropped first). Callers must hold ``_cache_lock``.
        """
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries (simple FIFO eviction)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using the provider when available, else OpenAI.

        The provider path is retried up to 3 times; if all attempts fail the
        method falls back to the OpenAI client (when configured).

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            RuntimeError: If neither a provider nor an OpenAI client is set
            Exception: If the OpenAI fallback fails
        """
        # Prefer provider when available
        if self.embeddings_provider is not None:
            # Accept either a provider (with .embeddings()) or a direct embeddings client
            client = (
                self.embeddings_provider.embeddings()
                if hasattr(self.embeddings_provider, "embeddings")
                else self.embeddings_provider
            )
            for attempt in range(3):
                try:
                    vectors = await client.embed([text])
                    return vectors[0]
                except Exception as e:
                    self.logger.warning(
                        "Provider embedding failed, retrying...", error=str(e)
                    )
                    # Fix: don't sleep after the final attempt — it only
                    # delayed the OpenAI fallback by 0.5s for no benefit.
                    if attempt < 2:
                        await asyncio.sleep(0.5)

        # Fallback to OpenAI (to keep backward compatibility & pass tests)
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model=self.embedding_model,
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                self.logger.error("OpenAI fallback failed", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")

    @staticmethod
    def _extract_hit_fields(hit: Any) -> dict[str, Any]:
        """Flatten a scored hit (Qdrant point or FilterResult) into a result dict.

        Treats a missing/None payload as empty, so every field falls back to
        its default — identical to checking ``hit.payload`` per field.
        """
        payload = hit.payload or {}
        return {
            "score": hit.score,
            "text": payload.get("content", ""),
            "metadata": payload.get("metadata", {}),
            "source_type": payload.get("source_type", "unknown"),
            # Extract fields directly from Qdrant payload
            "title": payload.get("title", ""),
            "url": payload.get("url", ""),
            "document_id": payload.get("document_id", ""),
            "source": payload.get("source", ""),
            "created_at": payload.get("created_at", ""),
            "updated_at": payload.get("updated_at", ""),
            "contextual_content": payload.get("contextual_content", ""),
        }

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query (may contain ``field:value`` filters)
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # Parse query for field-specific filters (field:value syntax)
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            # scroll() returns (points, next_page_offset); wrap each point in a
            # FilterResult with a fixed score so extraction below is uniform.
            results = [FilterResult(1.0, point.payload) for point in scroll_results[0]]
        else:
            # Hybrid search (vector search + field filters)
            from qdrant_client.http import models

            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # Use query_points API (qdrant-client 1.10+)
            query_response = await self.qdrant_client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # Explicitly request payload data
            )
            results = query_response.points

        extracted_results = [self._extract_hit_fields(hit) for hit in results]

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results.

        NOTE(review): not taken under ``_cache_lock``; a concurrent
        ``vector_search`` may repopulate entries immediately — confirm
        this best-effort behavior is acceptable.
        """
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            from qdrant_client.http import models

            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    # Note: _build_filter kept for backward compatibility - prefer FieldQueryParser.create_qdrant_filter()

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> qdrant_models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)