Coverage for src/qdrant_loader_mcp_server/search/hybrid_search.py: 100%
187 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-04 05:45 +0000
1"""Hybrid search implementation combining vector and keyword search."""
3import re
4from dataclasses import dataclass
5from typing import Any, Dict, List, Optional
7import numpy as np
8from openai import AsyncOpenAI
9from qdrant_client import QdrantClient
10from qdrant_client.http import models
11from qdrant_client.models import Filter, PointStruct, ScoredPoint
12from rank_bm25 import BM25Okapi
14from .models import SearchResult
15from ..utils.logging import LoggingConfig
17logger = LoggingConfig.get_logger(__name__)
@dataclass
class HybridSearchResult:
    """Container for hybrid search results.

    Bundles the combined relevance score with the raw vector and keyword
    score components, plus source, project, hierarchy, and attachment
    metadata extracted from the matched document's payload.
    """

    # Core result fields
    score: float  # combined weighted score (vector + keyword components)
    text: str  # matched document content
    source_type: str  # origin of the document (e.g. "unknown" when missing)
    source_title: str
    source_url: str | None = None
    file_path: str | None = None
    repo_name: str | None = None
    vector_score: float = 0.0  # raw dense-vector similarity component
    keyword_score: float = 0.0  # raw BM25 keyword component

    # Project information (for multi-project support)
    project_id: str | None = None
    project_name: str | None = None
    project_description: str | None = None
    collection_name: str | None = None

    # Hierarchy information (primarily for Confluence)
    parent_id: str | None = None
    parent_title: str | None = None
    breadcrumb_text: str | None = None
    depth: int | None = None
    children_count: int | None = None
    # Human-readable "Path: ... | Depth: ... | Children: ..." summary
    hierarchy_context: str | None = None

    # Attachment information (for files attached to documents)
    is_attachment: bool = False
    parent_document_id: str | None = None
    parent_document_title: str | None = None
    attachment_id: str | None = None
    original_filename: str | None = None
    file_size: int | None = None  # size in bytes
    mime_type: str | None = None
    attachment_author: str | None = None
    # Human-readable "File: ... | Size: ... | Type: ... | Author: ..." summary
    attachment_context: str | None = None
class HybridSearchEngine:
    """Service for hybrid search combining vector and keyword search.

    Dense (embedding) results and sparse (BM25) results are fetched
    independently and then merged with a weighted sum of their scores.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        openai_client: AsyncOpenAI,
        collection_name: str,
        vector_weight: float = 0.6,
        keyword_weight: float = 0.3,
        metadata_weight: float = 0.1,
        min_score: float = 0.3,
        dense_vector_name: str = "dense",
        sparse_vector_name: str = "sparse",
        alpha: float = 0.5,
    ):
        """Initialize the hybrid search service.

        Args:
            qdrant_client: Qdrant client instance
            openai_client: OpenAI client instance
            collection_name: Name of the Qdrant collection
            vector_weight: Weight for vector search scores (0-1)
            keyword_weight: Weight for keyword search scores (0-1)
            metadata_weight: Weight for metadata-based scoring (0-1)
            min_score: Minimum combined score threshold
            dense_vector_name: Name of the dense vector field
            sparse_vector_name: Name of the sparse vector field
            alpha: Weight for dense search (1-alpha for sparse search)
        """
        self.qdrant_client = qdrant_client
        self.openai_client = openai_client
        self.collection_name = collection_name
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight
        # NOTE(review): metadata_weight is stored but never folded into the
        # combined score in _combine_results — confirm whether metadata
        # scoring is still planned or the parameter should be documented as
        # reserved.
        self.metadata_weight = metadata_weight
        self.min_score = min_score
        # NOTE(review): the named-vector fields and alpha are accepted for
        # API compatibility but are not referenced by any method below.
        self.dense_vector_name = dense_vector_name
        self.sparse_vector_name = sparse_vector_name
        self.alpha = alpha
        self.logger = LoggingConfig.get_logger(__name__)

        # Common query expansions for frequently used terms
        self.query_expansions = {
            "product requirements": [
                "PRD",
                "requirements document",
                "product specification",
            ],
            "requirements": ["specs", "requirements document", "features"],
            "architecture": ["system design", "technical architecture"],
            "UI": ["user interface", "frontend", "design"],
            "API": ["interface", "endpoints", "REST"],
            "database": ["DB", "data storage", "persistence"],
            "security": ["auth", "authentication", "authorization"],
        }

    async def _expand_query(self, query: str) -> str:
        """Expand query with related terms for better matching.

        Only the first matching expansion key is applied; the expansion
        terms are appended to the original query text.
        """
        lower_query = query.lower()
        for key, expansions in self.query_expansions.items():
            if key.lower() in lower_query:
                expanded_query = f"{query} {' '.join(expansions)}"
                self.logger.debug(
                    "Expanded query",
                    original_query=query,
                    expanded_query=expanded_query,
                )
                return expanded_query
        return query

    async def _get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using OpenAI.

        Raises:
            Exception: re-raised from the OpenAI client after logging.
        """
        try:
            # Model name is hard-coded; all queries use the same embedding
            # space as the indexed documents are assumed to.
            response = await self.openai_client.embeddings.create(
                model="text-embedding-3-small",
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            self.logger.error("Failed to get embedding", error=str(e))
            raise

    async def search(
        self,
        query: str,
        limit: int = 5,
        source_types: list[str] | None = None,
        project_ids: list[str] | None = None,
    ) -> list[SearchResult]:
        """Perform hybrid search combining vector and keyword search.

        Args:
            query: Search query text
            limit: Maximum number of results to return
            source_types: Optional list of source types to filter by
            project_ids: Optional list of project IDs to filter by

        Returns:
            Ranked list of SearchResult objects, at most ``limit`` long.

        Raises:
            Exception: re-raised from the underlying search calls after logging.
        """
        self.logger.debug(
            "Starting hybrid search",
            query=query,
            limit=limit,
            source_types=source_types,
            project_ids=project_ids,
        )

        try:
            # Expand query with related terms (vector search only; keyword
            # search uses the original query).
            expanded_query = await self._expand_query(query)

            # Over-fetch (3x) from each backend so the merge/filter step has
            # enough candidates to fill the requested limit.
            vector_results = await self._vector_search(
                expanded_query, limit * 3, project_ids
            )
            keyword_results = await self._keyword_search(query, limit * 3, project_ids)

            # Analyze query for context
            query_context = self._analyze_query(query)

            # Combine and rerank results
            combined_results = await self._combine_results(
                vector_results,
                keyword_results,
                query_context,
                limit,
                source_types,
                project_ids,
            )

            # Convert internal HybridSearchResult containers to the public
            # SearchResult model.
            return [
                SearchResult(
                    score=result.score,
                    text=result.text,
                    source_type=result.source_type,
                    source_title=result.source_title,
                    source_url=result.source_url,
                    file_path=result.file_path,
                    repo_name=result.repo_name,
                    project_id=result.project_id,
                    project_name=result.project_name,
                    project_description=result.project_description,
                    collection_name=result.collection_name,
                    parent_id=result.parent_id,
                    parent_title=result.parent_title,
                    breadcrumb_text=result.breadcrumb_text,
                    depth=result.depth,
                    children_count=result.children_count,
                    hierarchy_context=result.hierarchy_context,
                    is_attachment=result.is_attachment,
                    parent_document_id=result.parent_document_id,
                    parent_document_title=result.parent_document_title,
                    attachment_id=result.attachment_id,
                    original_filename=result.original_filename,
                    file_size=result.file_size,
                    mime_type=result.mime_type,
                    attachment_author=result.attachment_author,
                    attachment_context=result.attachment_context,
                )
                for result in combined_results
            ]

        except Exception as e:
            self.logger.error("Error in hybrid search", error=str(e), query=query)
            raise

    def _analyze_query(self, query: str) -> dict[str, Any]:
        """Analyze query to determine intent and context.

        Returns:
            Dict with is_question/is_broad/is_specific flags, a
            probable_intent label, and the lowercased keywords (3+ chars).
        """
        lower_query = query.lower()
        word_count = len(query.split())
        context = {
            "is_question": bool(
                re.search(r"\?|what|how|why|when|who|where", lower_query)
            ),
            "is_broad": word_count < 5,
            "is_specific": word_count > 7,
            "probable_intent": "informational",
            # findall on the lowered query already yields lowercase tokens
            "keywords": re.findall(r"\b\w{3,}\b", lower_query),
        }

        # Intent detection: first matching category wins.
        if "how to" in lower_query or "steps" in lower_query:
            context["probable_intent"] = "procedural"
        elif any(
            term in lower_query for term in ["requirements", "prd", "specification"]
        ):
            context["probable_intent"] = "requirements"
        elif any(
            term in lower_query for term in ["architecture", "design", "structure"]
        ):
            context["probable_intent"] = "architecture"

        return context

    async def _vector_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform vector search using Qdrant.

        NOTE(review): the qdrant client call here is synchronous and will
        block the event loop while it runs — confirm this is acceptable for
        the expected query latency.
        """
        query_embedding = await self._get_embedding(query)

        search_params = models.SearchParams(hnsw_ef=128, exact=False)

        results = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
            score_threshold=self.min_score,
            search_params=search_params,
            query_filter=self._build_filter(project_ids),
        )

        # hit.payload may be None; fall back to safe defaults.
        return [
            {
                "score": hit.score,
                "text": hit.payload.get("content", "") if hit.payload else "",
                "metadata": hit.payload.get("metadata", {}) if hit.payload else {},
                "source_type": (
                    hit.payload.get("source_type", "unknown")
                    if hit.payload
                    else "unknown"
                ),
            }
            for hit in results
        ]

    async def _keyword_search(
        self, query: str, limit: int, project_ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25 over scrolled documents.

        NOTE(review): this scrolls up to 10000 points per query and builds
        the BM25 index from scratch each time — a cached index may be needed
        for larger collections.
        """
        scroll_results = self.qdrant_client.scroll(
            collection_name=self.collection_name,
            limit=10000,
            with_payload=True,
            with_vectors=False,
            scroll_filter=self._build_filter(project_ids),
        )

        documents = []
        metadata_list = []
        source_types = []

        for point in scroll_results[0]:
            if point.payload:
                documents.append(point.payload.get("content", ""))
                metadata_list.append(point.payload.get("metadata", {}))
                source_types.append(point.payload.get("source_type", "unknown"))

        # Guard: BM25Okapi divides by the corpus size at construction, so an
        # empty corpus (empty collection or nothing matching the project
        # filter) would raise ZeroDivisionError.
        if not documents:
            return []

        tokenized_docs = [doc.split() for doc in documents]
        bm25 = BM25Okapi(tokenized_docs)

        scores = bm25.get_scores(query.split())

        # Top-`limit` indices, highest score first.
        top_indices = np.argsort(scores)[-limit:][::-1]

        return [
            {
                "score": float(scores[idx]),
                "text": documents[idx],
                "metadata": metadata_list[idx],
                "source_type": source_types[idx],
            }
            for idx in top_indices
            if scores[idx] > 0
        ]

    async def _combine_results(
        self,
        vector_results: list[dict[str, Any]],
        keyword_results: list[dict[str, Any]],
        query_context: dict[str, Any],
        limit: int,
        source_types: list[str] | None = None,
        project_ids: list[str] | None = None,
    ) -> list[HybridSearchResult]:
        """Combine and rerank results from vector and keyword search.

        Results are deduplicated by exact text match; a document found by
        both backends gets both score components. project_ids is accepted
        for signature symmetry — project filtering already happened in the
        upstream Qdrant queries.
        """
        combined_dict = {}

        # Process vector results
        for result in vector_results:
            text = result["text"]
            if text not in combined_dict:
                combined_dict[text] = {
                    "text": text,
                    "metadata": result["metadata"],
                    "source_type": result["source_type"],
                    "vector_score": result["score"],
                    "keyword_score": 0.0,
                }

        # Process keyword results, merging onto vector hits when the text matches
        for result in keyword_results:
            text = result["text"]
            if text in combined_dict:
                combined_dict[text]["keyword_score"] = result["score"]
            else:
                combined_dict[text] = {
                    "text": text,
                    "metadata": result["metadata"],
                    "source_type": result["source_type"],
                    "vector_score": 0.0,
                    "keyword_score": result["score"],
                }

        # Calculate combined scores and create results
        combined_results = []
        for text, info in combined_dict.items():
            # Skip if source type doesn't match filter
            if source_types and info["source_type"] not in source_types:
                continue

            metadata = info["metadata"]
            combined_score = (
                self.vector_weight * info["vector_score"]
                + self.keyword_weight * info["keyword_score"]
            )

            if combined_score >= self.min_score:
                hierarchy_info = self._extract_metadata_info(metadata)
                project_info = self._extract_project_info(metadata)

                combined_results.append(
                    HybridSearchResult(
                        score=combined_score,
                        text=text,
                        source_type=info["source_type"],
                        source_title=metadata.get("title", ""),
                        source_url=metadata.get("url"),
                        file_path=metadata.get("file_path"),
                        repo_name=metadata.get("repository_name"),
                        vector_score=info["vector_score"],
                        keyword_score=info["keyword_score"],
                        project_id=project_info["project_id"],
                        project_name=project_info["project_name"],
                        project_description=project_info["project_description"],
                        collection_name=project_info["collection_name"],
                        parent_id=hierarchy_info["parent_id"],
                        parent_title=hierarchy_info["parent_title"],
                        breadcrumb_text=hierarchy_info["breadcrumb_text"],
                        depth=hierarchy_info["depth"],
                        children_count=hierarchy_info["children_count"],
                        hierarchy_context=hierarchy_info["hierarchy_context"],
                        is_attachment=hierarchy_info["is_attachment"],
                        parent_document_id=hierarchy_info["parent_document_id"],
                        parent_document_title=hierarchy_info["parent_document_title"],
                        attachment_id=hierarchy_info["attachment_id"],
                        original_filename=hierarchy_info["original_filename"],
                        file_size=hierarchy_info["file_size"],
                        mime_type=hierarchy_info["mime_type"],
                        attachment_author=hierarchy_info["attachment_author"],
                        attachment_context=hierarchy_info["attachment_context"],
                    )
                )

        # Sort by combined score, best first
        combined_results.sort(key=lambda x: x.score, reverse=True)
        return combined_results[:limit]

    def _extract_metadata_info(self, metadata: dict) -> dict:
        """Extract hierarchy and attachment information from document metadata.

        Args:
            metadata: Document metadata

        Returns:
            Dictionary with hierarchy and attachment information
        """
        return {
            **self._extract_hierarchy_info(metadata),
            **self._extract_attachment_info(metadata),
        }

    def _extract_hierarchy_info(self, metadata: dict) -> dict:
        """Extract hierarchy fields and build a display context string."""
        hierarchy_info = {
            "parent_id": metadata.get("parent_id"),
            "parent_title": metadata.get("parent_title"),
            "breadcrumb_text": metadata.get("breadcrumb_text"),
            "depth": metadata.get("depth"),
            "children_count": None,
            "hierarchy_context": None,
        }

        children = metadata.get("children", [])
        if children:
            hierarchy_info["children_count"] = len(children)

        # Generate hierarchy context for display
        if metadata.get("breadcrumb_text") or metadata.get("depth") is not None:
            context_parts = []
            if metadata.get("breadcrumb_text"):
                context_parts.append(f"Path: {metadata.get('breadcrumb_text')}")
            if metadata.get("depth") is not None:
                context_parts.append(f"Depth: {metadata.get('depth')}")
            if (
                hierarchy_info["children_count"] is not None
                and hierarchy_info["children_count"] > 0
            ):
                context_parts.append(f"Children: {hierarchy_info['children_count']}")
            if context_parts:
                hierarchy_info["hierarchy_context"] = " | ".join(context_parts)

        return hierarchy_info

    def _extract_attachment_info(self, metadata: dict) -> dict:
        """Extract attachment fields and build a display context string."""
        attachment_info = {
            "is_attachment": metadata.get("is_attachment", False),
            "parent_document_id": metadata.get("parent_document_id"),
            "parent_document_title": metadata.get("parent_document_title"),
            "attachment_id": metadata.get("attachment_id"),
            "original_filename": metadata.get("original_filename"),
            "file_size": metadata.get("file_size"),
            "mime_type": metadata.get("mime_type"),
            "attachment_author": metadata.get("attachment_author")
            or metadata.get("author"),
            "attachment_context": None,
        }

        # Generate attachment context for display
        if attachment_info["is_attachment"]:
            context_parts = []
            if attachment_info["original_filename"]:
                context_parts.append(f"File: {attachment_info['original_filename']}")
            if attachment_info["file_size"]:
                context_parts.append(
                    f"Size: {self._format_file_size(attachment_info['file_size'])}"
                )
            if attachment_info["mime_type"]:
                context_parts.append(f"Type: {attachment_info['mime_type']}")
            if attachment_info["attachment_author"]:
                context_parts.append(f"Author: {attachment_info['attachment_author']}")
            if context_parts:
                attachment_info["attachment_context"] = " | ".join(context_parts)

        return attachment_info

    @staticmethod
    def _format_file_size(size: int) -> str:
        """Convert a byte count to a human-readable B/KB/MB/GB string."""
        if size < 1024:
            return f"{size} B"
        if size < 1024 * 1024:
            return f"{size / 1024:.1f} KB"
        if size < 1024 * 1024 * 1024:
            return f"{size / (1024 * 1024):.1f} MB"
        return f"{size / (1024 * 1024 * 1024):.1f} GB"

    def _extract_project_info(self, metadata: dict) -> dict:
        """Extract project information from document metadata.

        Args:
            metadata: Document metadata

        Returns:
            Dictionary with project information
        """
        return {
            "project_id": metadata.get("project_id"),
            "project_name": metadata.get("project_name"),
            "project_description": metadata.get("project_description"),
            "collection_name": metadata.get("collection_name"),
        }

    def _build_filter(
        self, project_ids: list[str] | None = None
    ) -> models.Filter | None:
        """Build a Qdrant filter matching any of the given project IDs.

        Returns None when no project IDs are given so callers pass an
        unfiltered query to Qdrant.
        """
        if not project_ids:
            return None

        return models.Filter(
            must=[
                models.FieldCondition(
                    key="project_id", match=models.MatchAny(any=project_ids)
                )
            ]
        )