Coverage for src/qdrant_loader_mcp_server/search/components/vector_search_service.py: 93%

103 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-13 09:20 +0000

1"""Vector search service for hybrid search.""" 

2 

3import hashlib 

4import time 

5from asyncio import Lock 

6from dataclasses import dataclass 

7from typing import Any 

8 

9from openai import AsyncOpenAI 

10from qdrant_client import AsyncQdrantClient 

11from qdrant_client.http import models 

12 

13from ...utils.logging import LoggingConfig 

14from .field_query_parser import FieldQueryParser 

15 

16 

@dataclass
class FilterResult:
    """Minimal result wrapper for filter-only (non-vector) searches.

    The scroll-based branch of ``VectorSearchService.vector_search`` wraps
    each matching point in one of these so filter-only hits expose the same
    ``score``/``payload`` attributes as regular Qdrant search hits and can
    flow through the shared payload-extraction loop.
    """

    # Always 1.0 when produced by the filter-only code path, since no
    # similarity score exists for pure metadata matches.
    score: float
    # Raw Qdrant point payload. NOTE(review): callers pass point.payload
    # straight through, which may be None despite this annotation — confirm.
    payload: dict

21 

22 

class VectorSearchService:
    """Handles vector search operations using Qdrant."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        openai_client: AsyncOpenAI,
        collection_name: str,
        min_score: float = 0.3,
        cache_enabled: bool = True,
        cache_ttl: int = 300,
        cache_max_size: int = 500,
        hnsw_ef: int = 128,
        use_exact_search: bool = False,
    ):
        """Initialize the vector search service.

        Args:
            qdrant_client: Asynchronous Qdrant client instance (AsyncQdrantClient)
            openai_client: OpenAI client instance
            collection_name: Name of the Qdrant collection
            min_score: Minimum score threshold
            cache_enabled: Whether to enable search result caching
            cache_ttl: Cache time-to-live in seconds
            cache_max_size: Maximum number of cached results
            hnsw_ef: HNSW ``ef`` search-time parameter forwarded to Qdrant
                (higher values trade latency for recall)
            use_exact_search: If True, request exact (non-approximate) search
                from Qdrant instead of the HNSW index
        """
        self.qdrant_client = qdrant_client
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.min_score = min_score

        # Search result caching configuration
        self.cache_enabled = cache_enabled
        self.cache_ttl = cache_ttl
        self.cache_max_size = cache_max_size
        self._search_cache: dict[str, dict[str, Any]] = {}
        self._cache_lock: Lock = Lock()

        # Cache performance metrics (lifetime counters, never reset)
        self._cache_hits = 0
        self._cache_misses = 0

        # Field query parser for handling field:value syntax
        self.field_parser = FieldQueryParser()

        self.logger = LoggingConfig.get_logger(__name__)

        # Qdrant search parameters
        self.hnsw_ef = hnsw_ef
        self.use_exact_search = use_exact_search

    def _generate_cache_key(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> str:
        """Generate a cache key for search parameters.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            SHA256 hash of search parameters for cache key
        """
        # Sort project IDs so the key is independent of caller ordering.
        project_str = ",".join(sorted(project_ids)) if project_ids else "none"
        cache_input = (
            f"{query}|{limit}|{project_str}|{self.min_score}|{self.collection_name}"
        )
        return hashlib.sha256(cache_input.encode()).hexdigest()

    def _cleanup_expired_cache(self) -> None:
        """Remove expired entries from cache and enforce the size cap.

        NOTE: does not acquire ``self._cache_lock`` itself; concurrent
        callers (see ``vector_search``) invoke it while holding the lock.
        """
        if not self.cache_enabled:
            return

        current_time = time.time()
        expired_keys = [
            key
            for key, value in self._search_cache.items()
            if current_time - value["timestamp"] > self.cache_ttl
        ]

        for key in expired_keys:
            del self._search_cache[key]

        # Also enforce max size limit
        if len(self._search_cache) > self.cache_max_size:
            # Evict the oldest entries first (ordered by insertion timestamp).
            sorted_items = sorted(
                self._search_cache.items(), key=lambda x: x[1]["timestamp"]
            )
            items_to_remove = len(self._search_cache) - self.cache_max_size
            for key, _ in sorted_items[:items_to_remove]:
                del self._search_cache[key]

    async def get_embedding(
        self, text: str, model: str = "text-embedding-3-small"
    ) -> list[float]:
        """Get embedding for text using OpenAI.

        Args:
            text: Text to get embedding for
            model: OpenAI embedding model name. Defaults to the previously
                hard-coded ``text-embedding-3-small`` for backward
                compatibility.

        Returns:
            List of embedding values

        Raises:
            Exception: If embedding generation fails
        """
        try:
            response = await self.openai_client.embeddings.create(
                model=model,
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            self.logger.error("Failed to get embedding", error=str(e))
            raise

    async def vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant with caching support.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Generate cache key and check cache first
        cache_key = self._generate_cache_key(query, limit, project_ids)

        if self.cache_enabled:
            # Guard cache reads/cleanup with the async lock
            async with self._cache_lock:
                self._cleanup_expired_cache()

                # Check cache for existing results
                cached_entry = self._search_cache.get(cache_key)
                if cached_entry is not None:
                    current_time = time.time()

                    # Verify cache entry is still valid
                    if current_time - cached_entry["timestamp"] <= self.cache_ttl:
                        self._cache_hits += 1
                        self.logger.debug(
                            "Search cache hit",
                            query=query[:50],  # Truncate for logging
                            cache_hits=self._cache_hits,
                            cache_misses=self._cache_misses,
                            hit_rate=f"{self._cache_hits / (self._cache_hits + self._cache_misses) * 100:.1f}%",
                        )
                        return cached_entry["results"]

        # Cache miss - perform actual search
        self._cache_misses += 1

        self.logger.debug(
            "Search cache miss - performing QDrant search",
            query=query[:50],  # Truncate for logging
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

        # Parse query for field-specific filters (field:value syntax)
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Determine search strategy based on parsed query
        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search (exact field matching)
            self.logger.debug("Using filter-only search for exact field matching")
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            # For filter-only searches, use scroll with filter instead of vector search
            scroll_results = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                scroll_filter=query_filter,
                with_payload=True,
                with_vectors=False,
            )

            # scroll() returns (points, next_page_offset); wrap each point so
            # it exposes the same score/payload attributes as a search hit.
            points, _next_offset = scroll_results
            results = [FilterResult(1.0, point.payload) for point in points]
        else:
            # Hybrid search (vector search + field filters)
            search_query = parsed_query.text_query or query
            query_embedding = await self.get_embedding(search_query)

            search_params = models.SearchParams(
                hnsw_ef=self.hnsw_ef, exact=bool(self.use_exact_search)
            )

            # Combine field filters with project filters
            query_filter = self.field_parser.create_qdrant_filter(
                parsed_query.field_queries, project_ids
            )

            results = await self.qdrant_client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                score_threshold=self.min_score,
                search_params=search_params,
                query_filter=query_filter,
                with_payload=True,  # Explicitly request payload data
            )

        extracted_results = []
        for hit in results:
            # Hoist the None-guard once instead of repeating it per field.
            payload = hit.payload or {}
            extracted = {
                "score": hit.score,
                "text": payload.get("content", ""),
                "metadata": payload.get("metadata", {}),
                "source_type": payload.get("source_type", "unknown"),
                # Extract fields directly from Qdrant payload
                "title": payload.get("title", ""),
                "url": payload.get("url", ""),
                "document_id": payload.get("document_id", ""),
                "source": payload.get("source", ""),
                "created_at": payload.get("created_at", ""),
                "updated_at": payload.get("updated_at", ""),
            }

            extracted_results.append(extracted)

        # Store results in cache if caching is enabled
        if self.cache_enabled:
            async with self._cache_lock:
                self._search_cache[cache_key] = {
                    "results": extracted_results,
                    "timestamp": time.time(),
                }

                self.logger.debug(
                    "Cached search results",
                    query=query[:50],
                    results_count=len(extracted_results),
                    cache_size=len(self._search_cache),
                )

        return extracted_results

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics.

        Returns:
            Dictionary with cache hit rate, size, and other metrics
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = (
            (self._cache_hits / total_requests * 100) if total_requests > 0 else 0.0
        )

        return {
            "cache_enabled": self.cache_enabled,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate_percent": round(hit_rate, 2),
            "cache_size": len(self._search_cache),
            "cache_max_size": self.cache_max_size,
            "cache_ttl_seconds": self.cache_ttl,
        }

    def clear_cache(self) -> None:
        """Clear all cached search results.

        Hit/miss counters are not reset here, so ``get_cache_stats`` keeps
        reporting lifetime totals after a clear.
        """
        self._search_cache.clear()
        self.logger.info("Search result cache cleared")

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Legacy method for backward compatibility - use FieldQueryParser instead.

        Args:
            project_ids: Optional project ID filters

        Returns:
            Qdrant Filter object or None
        """
        # Uses the module-level `models` import; the previous redundant
        # function-local re-import was removed.
        if project_ids:
            return models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )
        return None


# Note: _build_filter method added back for backward compatibility - prefer FieldQueryParser.create_qdrant_filter()