Coverage for src/qdrant_loader_mcp_server/search/hybrid_search.py: 100%

187 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-04 05:45 +0000

1"""Hybrid search implementation combining vector and keyword search.""" 

2 

3import re 

4from dataclasses import dataclass 

5from typing import Any, Dict, List, Optional 

6 

7import numpy as np 

8from openai import AsyncOpenAI 

9from qdrant_client import QdrantClient 

10from qdrant_client.http import models 

11from qdrant_client.models import Filter, PointStruct, ScoredPoint 

12from rank_bm25 import BM25Okapi 

13 

14from .models import SearchResult 

15from ..utils.logging import LoggingConfig 

16 

# Module-level logger. The HybridSearchEngine code in this module logs through
# its own ``self.logger`` (created the same way in ``__init__``) instead.
logger = LoggingConfig.get_logger(__name__)

18 

19 

@dataclass
class HybridSearchResult:
    """Container for hybrid search results.

    One merged hit from the vector and keyword searches: the combined score,
    the two component scores, and display fields copied from the document's
    payload metadata.
    """

    # Combined relevance score used for ranking (weighted vector + keyword).
    score: float
    # Text of the matched chunk.
    text: str
    # Source type recorded in the point payload.
    source_type: str
    # Document title taken from the metadata "title" entry.
    source_title: str
    source_url: str | None = None
    file_path: str | None = None
    repo_name: str | None = None
    # Unweighted component scores; a hit found by only one of the two searches
    # keeps 0.0 for the other.
    vector_score: float = 0.0
    keyword_score: float = 0.0

    # Project information (for multi-project support)
    project_id: str | None = None
    project_name: str | None = None
    project_description: str | None = None
    collection_name: str | None = None

    # Hierarchy information (primarily for Confluence)
    parent_id: str | None = None
    parent_title: str | None = None
    breadcrumb_text: str | None = None
    depth: int | None = None
    # Number of entries in the metadata "children" list, if any.
    children_count: int | None = None
    # Human-readable " | "-joined summary of path, depth and children count.
    hierarchy_context: str | None = None

    # Attachment information (for files attached to documents)
    is_attachment: bool = False
    parent_document_id: str | None = None
    parent_document_title: str | None = None
    attachment_id: str | None = None
    original_filename: str | None = None
    # Size in bytes; rendered as B/KB/MB/GB inside attachment_context.
    file_size: int | None = None
    mime_type: str | None = None
    attachment_author: str | None = None
    # Human-readable " | "-joined summary of file name, size, type and author.
    attachment_context: str | None = None

58 

59 

class HybridSearchEngine:
    """Service for hybrid search combining vector and keyword search.

    A query is answered by a dense-vector search in Qdrant (using an OpenAI
    embedding of the expanded query) and a BM25 keyword ranking over the
    payload text of the collection. Hits from both searches are merged by
    their text and scored as
    ``vector_weight * vector_score + keyword_weight * keyword_score``.

    NOTE: calls on ``qdrant_client`` are made without ``await`` from async
    methods; with a blocking client each call stalls the event loop until it
    returns.
    """

    # A literal "?" or a question word matched as a whole word, so words such
    # as "whole" or "somehow" are not mistaken for "who"/"how".
    _QUESTION_PATTERN = re.compile(r"\?|\b(?:what|how|why|when|who|where)\b")

    # Words of three or more word characters, used as query keywords.
    _KEYWORD_PATTERN = re.compile(r"\b\w{3,}\b")

    def __init__(
        self,
        qdrant_client: QdrantClient,
        openai_client: AsyncOpenAI,
        collection_name: str,
        vector_weight: float = 0.6,
        keyword_weight: float = 0.3,
        metadata_weight: float = 0.1,
        min_score: float = 0.3,
        dense_vector_name: str = "dense",
        sparse_vector_name: str = "sparse",
        alpha: float = 0.5,
    ):
        """Initialize the hybrid search service.

        Args:
            qdrant_client: Qdrant client instance
            openai_client: OpenAI client instance used to embed queries
            collection_name: Name of the Qdrant collection
            vector_weight: Weight for vector search scores (0-1)
            keyword_weight: Weight for keyword search scores (0-1)
            metadata_weight: Weight for metadata-based scoring (0-1); stored
                but not currently used by the scoring in this class
            min_score: Minimum score threshold, applied both as Qdrant's
                ``score_threshold`` for vector hits and to the final combined
                score
            dense_vector_name: Name of the dense vector field (stored, not
                currently used by this class)
            sparse_vector_name: Name of the sparse vector field (stored, not
                currently used by this class)
            alpha: Weight for dense search (1-alpha for sparse search);
                stored, not currently used by this class
        """
        self.qdrant_client = qdrant_client
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight
        self.metadata_weight = metadata_weight
        self.min_score = min_score
        self.dense_vector_name = dense_vector_name
        self.sparse_vector_name = sparse_vector_name
        self.alpha = alpha
        self.logger = LoggingConfig.get_logger(__name__)

        # Common query expansions for frequently used terms. Dictionary order
        # matters: the first matching key wins in _expand_query.
        self.query_expansions = {
            "product requirements": [
                "PRD",
                "requirements document",
                "product specification",
            ],
            "requirements": ["specs", "requirements document", "features"],
            "architecture": ["system design", "technical architecture"],
            "UI": ["user interface", "frontend", "design"],
            "API": ["interface", "endpoints", "REST"],
            "database": ["DB", "data storage", "persistence"],
            "security": ["auth", "authentication", "authorization"],
        }

    async def _expand_query(self, query: str) -> str:
        """Expand query with related terms for better matching.

        Expansion keys are matched case-insensitively as whole words or
        phrases (an optional trailing "s" is accepted for plurals). Substring
        matching is not used, so short keys such as "UI" or "API" do not fire
        on words like "build" or "rapid". Only the first matching key, in
        dictionary order, is applied.

        Args:
            query: Original user query.

        Returns:
            ``query`` followed by the expansion terms of the first matching
            key, or ``query`` unchanged when no key matches.
        """
        lower_query = query.lower()

        for key, expansions in self.query_expansions.items():
            pattern = rf"\b{re.escape(key.lower())}s?\b"
            if not re.search(pattern, lower_query):
                continue

            expansion_terms = " ".join(expansions)
            expanded_query = f"{query} {expansion_terms}"
            self.logger.debug(
                "Expanded query",
                original_query=query,
                expanded_query=expanded_query,
            )
            return expanded_query

        return query

    async def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using OpenAI.

        Args:
            text: Text to embed.

        Returns:
            The embedding vector of the first (only) input.

        Raises:
            Exception: Whatever the OpenAI client raised, after logging it.
        """
        try:
            response = await self.openai_client.embeddings.create(
                model="text-embedding-3-small",
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            self.logger.error("Failed to get embedding", error=str(e))
            raise

    async def search(
        self,
        query: str,
        limit: int = 5,
        source_types: list[str] | None = None,
        project_ids: list[str] | None = None,
    ) -> list[SearchResult]:
        """Perform hybrid search combining vector and keyword search.

        Each sub-search fetches ``limit * 3`` candidates. This is presumably so
        that enough remain after merging, source-type filtering and
        thresholding.

        Args:
            query: Search query text
            limit: Maximum number of results to return
            source_types: Optional list of source types to filter by
            project_ids: Optional list of project IDs to filter by

        Returns:
            Up to ``limit`` results ordered by descending combined score.

        Raises:
            Exception: Any error raised while searching is logged and
                re-raised unchanged.
        """
        self.logger.debug(
            "Starting hybrid search",
            query=query,
            limit=limit,
            source_types=source_types,
            project_ids=project_ids,
        )

        try:
            # Only the vector search sees the expanded query; BM25 ranks on the
            # user's original wording.
            expanded_query = await self._expand_query(query)
            candidate_limit = limit * 3

            vector_results = await self._vector_search(
                expanded_query, candidate_limit, project_ids
            )
            keyword_results = await self._keyword_search(
                query, candidate_limit, project_ids
            )
            query_context = self._analyze_query(query)

            combined_results = await self._combine_results(
                vector_results,
                keyword_results,
                query_context,
                limit,
                source_types,
                project_ids,
            )
            return [self._to_search_result(result) for result in combined_results]

        except Exception as e:
            self.logger.error("Error in hybrid search", error=str(e), query=query)
            raise

    @staticmethod
    def _to_search_result(result: HybridSearchResult) -> SearchResult:
        """Convert an internal hybrid result into the public ``SearchResult``.

        The per-component ``vector_score``/``keyword_score`` are not copied.
        """
        return SearchResult(
            score=result.score,
            text=result.text,
            source_type=result.source_type,
            source_title=result.source_title,
            source_url=result.source_url,
            file_path=result.file_path,
            repo_name=result.repo_name,
            project_id=result.project_id,
            project_name=result.project_name,
            project_description=result.project_description,
            collection_name=result.collection_name,
            parent_id=result.parent_id,
            parent_title=result.parent_title,
            breadcrumb_text=result.breadcrumb_text,
            depth=result.depth,
            children_count=result.children_count,
            hierarchy_context=result.hierarchy_context,
            is_attachment=result.is_attachment,
            parent_document_id=result.parent_document_id,
            parent_document_title=result.parent_document_title,
            attachment_id=result.attachment_id,
            original_filename=result.original_filename,
            file_size=result.file_size,
            mime_type=result.mime_type,
            attachment_author=result.attachment_author,
            attachment_context=result.attachment_context,
        )

    def _analyze_query(self, query: str) -> dict[str, Any]:
        """Analyze query to determine intent and context.

        Returns:
            Dictionary with the following keys:

            - ``is_question``: a "?" or a whole-word question word is present.
            - ``is_broad``: fewer than 5 words.
            - ``is_specific``: more than 7 words.
            - ``probable_intent``: one of "informational", "procedural",
              "requirements" or "architecture".
            - ``keywords``: lower-cased words of 3+ characters.
        """
        lower_query = query.lower()
        word_count = len(query.split())

        if "how to" in lower_query or "steps" in lower_query:
            probable_intent = "procedural"
        elif any(
            term in lower_query for term in ("requirements", "prd", "specification")
        ):
            probable_intent = "requirements"
        elif any(
            term in lower_query for term in ("architecture", "design", "structure")
        ):
            probable_intent = "architecture"
        else:
            probable_intent = "informational"

        return {
            "is_question": bool(self._QUESTION_PATTERN.search(lower_query)),
            "is_broad": word_count < 5,
            "is_specific": word_count > 7,
            "probable_intent": probable_intent,
            "keywords": self._KEYWORD_PATTERN.findall(lower_query),
        }

    @staticmethod
    def _unpack_payload(
        payload: dict[str, Any] | None,
    ) -> tuple[str, dict[str, Any], str]:
        """Return ``(text, metadata, source_type)`` from a point payload.

        A missing or ``None`` content or metadata value falls back to ``""``
        or ``{}``. This keeps later tokenization and ``metadata.get`` calls
        from failing.
        """
        if not payload:
            return "", {}, "unknown"
        return (
            payload.get("content") or "",
            payload.get("metadata") or {},
            payload.get("source_type", "unknown"),
        )

    async def _vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant.

        Args:
            query: Text to embed and search for.
            limit: Maximum number of hits to request.
            project_ids: Optional project IDs to restrict the search to.

        Returns:
            One dict per hit with ``score``, ``text``, ``metadata`` and
            ``source_type`` keys. ``min_score`` is passed to Qdrant as
            ``score_threshold``.
        """
        query_embedding = await self._get_embedding(query)

        # Approximate (non-exact) search with an explicit HNSW ef value.
        search_params = models.SearchParams(hnsw_ef=128, exact=False)

        results = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
            score_threshold=self.min_score,
            search_params=search_params,
            query_filter=self._build_filter(project_ids),
        )

        hits: list[dict[str, Any]] = []
        for hit in results:
            text, metadata, source_type = self._unpack_payload(hit.payload)
            hits.append(
                {
                    "score": hit.score,
                    "text": text,
                    "metadata": metadata,
                    "source_type": source_type,
                }
            )
        return hits

    async def _keyword_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25.

        The BM25 index is built on the fly from the first ``scroll`` page,
        which holds at most 10,000 points. Later pages are not fetched.
        Documents and query are split on whitespace and lower-cased, so
        matching is case-insensitive.

        Args:
            query: Query text to rank documents against.
            limit: Maximum number of results to return.
            project_ids: Optional project IDs to restrict the corpus to.

        Returns:
            Up to ``limit`` dicts (``score``, ``text``, ``metadata``,
            ``source_type``) with a positive BM25 score, best first. The list
            is empty when there is nothing to rank: no documents, no tokens,
            or ``limit <= 0``.
        """
        scroll_results = self.qdrant_client.scroll(
            collection_name=self.collection_name,
            limit=10000,
            with_payload=True,
            with_vectors=False,
            scroll_filter=self._build_filter(project_ids),
        )

        documents: list[str] = []
        metadata_list: list[dict[str, Any]] = []
        source_types: list[str] = []

        # Element 0 holds the points of the first page; the paging offset is
        # not followed.
        for point in scroll_results[0]:
            if not point.payload:
                continue
            text, metadata, source_type = self._unpack_payload(point.payload)
            documents.append(text)
            metadata_list.append(metadata)
            source_types.append(source_type)

        tokenized_docs = [doc.lower().split() for doc in documents]

        # BM25Okapi cannot be built from an empty corpus (it divides by the
        # corpus size). A corpus without any tokens cannot produce a positive
        # score either. Also, a non-positive limit would turn the slice below
        # into "all indices".
        if limit <= 0 or not any(tokenized_docs):
            return []

        bm25 = BM25Okapi(tokenized_docs)
        scores = bm25.get_scores(query.lower().split())

        # Indices of the `limit` highest scores, best first.
        top_indices = np.argsort(scores)[::-1][:limit]

        return [
            {
                "score": float(scores[idx]),
                "text": documents[idx],
                "metadata": metadata_list[idx],
                "source_type": source_types[idx],
            }
            for idx in top_indices
            if scores[idx] > 0
        ]

    async def _combine_results(
        self,
        vector_results: list[dict[str, Any]],
        keyword_results: list[dict[str, Any]],
        query_context: dict[str, Any],
        limit: int,
        source_types: list[str] | None = None,
        project_ids: list[str] | None = None,
    ) -> list[HybridSearchResult]:
        """Combine and rerank results from vector and keyword search.

        Results are merged by their text:

        - The first vector hit for a given text is kept.
        - A keyword hit with the same text adds its score to that entry.
        - Keyword-only texts get a vector score of 0.0.

        Entries are then processed as follows:

        1. Filtered by ``source_types``.
        2. Scored as ``vector_weight * vector_score + keyword_weight *
           keyword_score``.
        3. Kept only when that score is at least ``min_score``.

        ``query_context`` and ``project_ids`` are not used by this method.

        Returns:
            Up to ``limit`` results, highest combined score first.
        """
        combined_dict: dict[str, dict[str, Any]] = {}

        # Process vector results
        for result in vector_results:
            text = result["text"]
            if text not in combined_dict:
                combined_dict[text] = {
                    "text": text,
                    "metadata": result["metadata"],
                    "source_type": result["source_type"],
                    "vector_score": result["score"],
                    "keyword_score": 0.0,
                }

        # Process keyword results
        for result in keyword_results:
            text = result["text"]
            if text in combined_dict:
                combined_dict[text]["keyword_score"] = result["score"]
            else:
                combined_dict[text] = {
                    "text": text,
                    "metadata": result["metadata"],
                    "source_type": result["source_type"],
                    "vector_score": 0.0,
                    "keyword_score": result["score"],
                }

        combined_results: list[HybridSearchResult] = []
        for text, info in combined_dict.items():
            # Skip if source type doesn't match filter
            if source_types and info["source_type"] not in source_types:
                continue

            combined_score = (
                self.vector_weight * info["vector_score"]
                + self.keyword_weight * info["keyword_score"]
            )
            if combined_score < self.min_score:
                continue

            # Tolerate a None metadata value in caller-supplied result dicts.
            metadata = info["metadata"] or {}
            hierarchy_info = self._extract_metadata_info(metadata)
            project_info = self._extract_project_info(metadata)

            combined_results.append(
                HybridSearchResult(
                    score=combined_score,
                    text=text,
                    source_type=info["source_type"],
                    source_title=metadata.get("title", ""),
                    source_url=metadata.get("url"),
                    file_path=metadata.get("file_path"),
                    repo_name=metadata.get("repository_name"),
                    vector_score=info["vector_score"],
                    keyword_score=info["keyword_score"],
                    project_id=project_info["project_id"],
                    project_name=project_info["project_name"],
                    project_description=project_info["project_description"],
                    collection_name=project_info["collection_name"],
                    parent_id=hierarchy_info["parent_id"],
                    parent_title=hierarchy_info["parent_title"],
                    breadcrumb_text=hierarchy_info["breadcrumb_text"],
                    depth=hierarchy_info["depth"],
                    children_count=hierarchy_info["children_count"],
                    hierarchy_context=hierarchy_info["hierarchy_context"],
                    is_attachment=hierarchy_info["is_attachment"],
                    parent_document_id=hierarchy_info["parent_document_id"],
                    parent_document_title=hierarchy_info["parent_document_title"],
                    attachment_id=hierarchy_info["attachment_id"],
                    original_filename=hierarchy_info["original_filename"],
                    file_size=hierarchy_info["file_size"],
                    mime_type=hierarchy_info["mime_type"],
                    attachment_author=hierarchy_info["attachment_author"],
                    attachment_context=hierarchy_info["attachment_context"],
                )
            )

        combined_results.sort(key=lambda r: r.score, reverse=True)
        return combined_results[:limit]

    def _extract_metadata_info(self, metadata: dict) -> dict:
        """Extract hierarchy and attachment information from document metadata.

        Args:
            metadata: Document metadata

        Returns:
            Dictionary with hierarchy and attachment information.

            Hierarchy keys: ``parent_id``, ``parent_title``,
            ``breadcrumb_text``, ``depth``, ``children_count``,
            ``hierarchy_context``.

            Attachment keys: ``is_attachment``, ``parent_document_id``,
            ``parent_document_title``, ``attachment_id``,
            ``original_filename``, ``file_size``, ``mime_type``,
            ``attachment_author``, ``attachment_context``.

            The two ``*_context`` values are " | "-joined display summaries,
            or ``None`` when there is nothing to show.
        """
        breadcrumb = metadata.get("breadcrumb_text")
        depth = metadata.get("depth")
        children = metadata.get("children", [])
        children_count = len(children) if children else None

        # A hierarchy summary is only produced when a path or depth exists;
        # a children count alone does not trigger it.
        hierarchy_context = None
        if breadcrumb or depth is not None:
            parts = []
            if breadcrumb:
                parts.append(f"Path: {breadcrumb}")
            if depth is not None:
                parts.append(f"Depth: {depth}")
            if children_count:
                parts.append(f"Children: {children_count}")
            hierarchy_context = " | ".join(parts)

        hierarchy_info = {
            "parent_id": metadata.get("parent_id"),
            "parent_title": metadata.get("parent_title"),
            "breadcrumb_text": breadcrumb,
            "depth": depth,
            "children_count": children_count,
            "hierarchy_context": hierarchy_context,
        }

        attachment_info = {
            "is_attachment": metadata.get("is_attachment", False),
            "parent_document_id": metadata.get("parent_document_id"),
            "parent_document_title": metadata.get("parent_document_title"),
            "attachment_id": metadata.get("attachment_id"),
            "original_filename": metadata.get("original_filename"),
            "file_size": metadata.get("file_size"),
            "mime_type": metadata.get("mime_type"),
            # Fall back to the generic document author.
            "attachment_author": metadata.get("attachment_author")
            or metadata.get("author"),
            "attachment_context": None,
        }

        if attachment_info["is_attachment"]:
            parts = []
            if attachment_info["original_filename"]:
                parts.append(f"File: {attachment_info['original_filename']}")
            if attachment_info["file_size"]:
                size_str = self._format_file_size(attachment_info["file_size"])
                parts.append(f"Size: {size_str}")
            if attachment_info["mime_type"]:
                parts.append(f"Type: {attachment_info['mime_type']}")
            if attachment_info["attachment_author"]:
                parts.append(f"Author: {attachment_info['attachment_author']}")
            if parts:
                attachment_info["attachment_context"] = " | ".join(parts)

        return {**hierarchy_info, **attachment_info}

    @staticmethod
    def _format_file_size(size: float) -> str:
        """Render a byte count as B, KB, MB or GB (1024-based, 1 decimal)."""
        kib = 1024
        if size < kib:
            return f"{size} B"
        if size < kib**2:
            return f"{size / kib:.1f} KB"
        if size < kib**3:
            return f"{size / kib**2:.1f} MB"
        return f"{size / kib**3:.1f} GB"

    def _extract_project_info(self, metadata: dict) -> dict:
        """Extract project information from document metadata.

        Args:
            metadata: Document metadata

        Returns:
            Dictionary with ``project_id``, ``project_name``,
            ``project_description`` and ``collection_name`` (``None`` when
            absent).
        """
        return {
            "project_id": metadata.get("project_id"),
            "project_name": metadata.get("project_name"),
            "project_description": metadata.get("project_description"),
            "collection_name": metadata.get("collection_name"),
        }

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Build a Qdrant filter based on project IDs.

        Returns:
            ``None`` (no filtering) when ``project_ids`` is empty or ``None``.
            Otherwise, a filter matching any of the given IDs.
        """
        if not project_ids:
            return None

        # NOTE(review): this matches a top-level payload key "project_id",
        # whereas _extract_project_info reads project_id from
        # payload["metadata"]. Confirm the indexer also writes project_id at
        # the top level of the payload.
        return models.Filter(
            must=[
                models.FieldCondition(
                    key="project_id", match=models.MatchAny(any=project_ids)
                )
            ]
        )