Coverage for src/qdrant_loader_mcp_server/search/components/vector_search_service.py: 94%

115 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

1"""Vector search service for hybrid search.""" 

2 

3import hashlib 

4import time 

5from asyncio import Lock 

6from dataclasses import dataclass 

7from typing import Any 

8 

9from qdrant_client import AsyncQdrantClient 

10from qdrant_client.http import models 

11 

12from ...utils.logging import LoggingConfig 

13from .field_query_parser import FieldQueryParser 

14 

15 

@dataclass
class FilterResult:
    """Lightweight stand-in for a Qdrant scored point.

    Used by filter-only (scroll-based) searches so their results expose the
    same ``score``/``payload`` attributes as real vector-search hits and can
    flow through the shared result-extraction loop.
    """

    # Similarity score; filter-only matches use a constant 1.0 since no
    # vector distance is computed for them.
    score: float
    # Raw Qdrant point payload as returned by the client.
    payload: dict

20 

21 

class VectorSearchService:
    """Handles vector search operations using Qdrant.

    Responsibilities:
    - Embedding generation via a pluggable embeddings provider, with an
      OpenAI client as fallback.
    - Parsing ``field:value`` query syntax and choosing between a
      filter-only (scroll) strategy and a hybrid (vector + filter) strategy.
    - An in-memory TTL cache of search results with FIFO eviction, guarded
      by an asyncio lock and instrumented with hit/miss counters.
    """

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
        *,
        embeddings_provider: Any | None = None,
        openai_client: Any | None = None,
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            collection_name: Name of the Qdrant collection
            min_score: Minimum score threshold applied to vector search results
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: HNSW ``ef`` search parameter forwarded to Qdrant
            use_exact_search: If True, request exact (non-approximate) search
            embeddings_provider: Optional embeddings provider (keyword-only).
                Preferred over ``openai_client`` when both are configured.
            openai_client: Optional OpenAI client (keyword-only). Used as a
                fallback for embedding generation when no provider is set.
        """
        self.qdrant_client = qdrant_client
        self.embeddings_provider = embeddings_provider
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Create a deterministic string from search parameters; project IDs
        # are sorted so that the key is order-independent.
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache.

        Drops entries older than ``cache_ttl`` seconds, then enforces
        ``cache_max_size`` with FIFO eviction (oldest timestamps first).
        Callers are expected to hold ``_cache_lock`` (as ``vector_search``
        does) since this mutates the shared cache dict.
        """
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Remove oldest entries (simple FIFO eviction)
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text, preferring the configured provider.

        The embeddings provider (when set) is tried first; the OpenAI client
        is only used as a fallback when no provider is configured.

        Args:
            text: Text to get embedding for

        Returns:
            List of embedding values

        Raises:
            RuntimeError: If neither a provider nor an OpenAI client is set
            Exception: If embedding generation fails (errors are logged and
                re-raised, never swallowed)
        """
        # Prefer provider when available
        if self.embeddings_provider is not None:
            try:
                # Accept either a provider (with .embeddings()) or a direct embeddings client
                client = (
                    self.embeddings_provider.embeddings()
                    if hasattr(self.embeddings_provider, "embeddings")
                    else self.embeddings_provider
                )
                vectors = await client.embed([text])
                return vectors[0]
            except Exception as e:
                self.logger.error("Provider embeddings failed", error=str(e))
                raise

        # Fallback to OpenAI client when provider is not configured
        if self.openai_client is not None:
            try:
                response = await self.openai_client.embeddings.create(
                    model="text-embedding-3-small",
                    input=text,
                )
                return response.data[0].embedding
            except Exception as e:
                # Do not fall back silently; propagate error as tests expect
                self.logger.error("Failed to get embedding", error=str(e))
                raise

        # Nothing configured
        raise RuntimeError("No embeddings provider or OpenAI client configured")

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query (may contain ``field:value`` filter syntax)
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        # NOTE(review): the miss counter also increments when caching is
        # disabled, so stats count every uncached search — confirm intended.
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # ✅ Parse query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            results = []
            for point in scroll_results[
                0
            ]:  # scroll_results is (points, next_page_offset)
                # Constant score 1.0: no vector similarity in filter-only mode
                results.append(FilterResult(1.0, point.payload))
        else:
            # Hybrid search (vector search + field filters)
            search_query = parsed_query.text_query if parsed_query.text_query else query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            results = await self.qdrant_client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # 🔧 CRITICAL: Explicitly request payload data
            )

        # Normalize both FilterResult objects and Qdrant scored points into
        # plain dicts; every payload access is None-guarded.
        extracted_results = []
        for hit in results:
            extracted = {
                "score": hit.score,
                "text": hit.payload.get("content", "") if hit.payload else "",
                "metadata": hit.payload.get("metadata", {}) if hit.payload else {},
                "source_type": (
                    hit.payload.get("source_type", "unknown")
                    if hit.payload
                    else "unknown"
                ),
                # Extract fields directly from Qdrant payload
                "title": hit.payload.get("title", "") if hit.payload else "",
                "url": hit.payload.get("url", "") if hit.payload else "",
                "document_id": (
                    hit.payload.get("document_id", "") if hit.payload else ""
                ),
                "source": hit.payload.get("source", "") if hit.payload else "",
                "created_at": hit.payload.get("created_at", "") if hit.payload else "",
                "updated_at": hit.payload.get("updated_at", "") if hit.payload else "",
            }

            extracted_results.append(extracted)

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results.

        NOTE(review): this synchronous method clears the dict without
        acquiring ``_cache_lock``; dict.clear() is atomic under the GIL, but
        confirm no in-flight search relies on a stable cache snapshot.
        """
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        if project_ids:
            # Uses the module-level `models` import; the previous local
            # re-import of qdrant_client.http.models was redundant.
            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None

    # Note: _build_filter method added back for backward compatibility - prefer FieldQueryParser.create_qdrant_filter()

    def build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Public wrapper for building a Qdrant filter for project constraints.

        Prefer using `FieldQueryParser.create_qdrant_filter` for field queries. This
        method exists to expose project filter building via a public API and wraps the
        legacy `_build_filter` implementation for compatibility.

        Args:
            project_ids: Optional list of project IDs to filter by.

        Returns:
            A Qdrant `models.Filter` or `None` if no filtering is needed.
        """
        return self._build_filter(project_ids)