Coverage for src / qdrant_loader_mcp_server / search / components / keyword_search_service.py: 93%
106 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-10 09:41 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-10 09:41 +0000
1"""Keyword search service for hybrid search."""
3from __future__ import annotations
5import asyncio
6from typing import TYPE_CHECKING, Any
8import nltk
9import numpy as np
10from nltk.stem import SnowballStemmer
11from nltk.tokenize import RegexpTokenizer
12from rank_bm25 import BM25Okapi
14if TYPE_CHECKING:
15 from qdrant_client import AsyncQdrantClient
17from ...utils.logging import LoggingConfig
18from .field_query_parser import FieldQueryParser
class KeywordSearchService:
    """Handles keyword search operations using BM25."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
    ):
        """Initialize the keyword search service.

        Args:
            qdrant_client: Qdrant client instance
            collection_name: Name of the Qdrant collection
        """
        self.qdrant_client = qdrant_client
        self.collection_name = collection_name
        self.field_parser = FieldQueryParser()
        self.logger = LoggingConfig.get_logger(__name__)
        self._stemmer = SnowballStemmer(language="english")
        # Build the regex tokenizer once; it is loop-invariant and safe to
        # share across calls (previously re-constructed on every _tokenize()).
        self._tokenizer = RegexpTokenizer(r"\b\w+\b")

        from nltk.corpus import stopwords

        # Best-effort download of the stopwords corpus; a failed download is
        # logged but not fatal — we fall back to no stopword filtering below.
        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            downloaded = nltk.download("stopwords", quiet=True)
            if not downloaded:
                self.logger.warning(
                    "NLTK stopwords download failed; continuing without stopword filtering."
                )

        try:
            self._stop_words = set(stopwords.words("english"))
        except LookupError:
            self.logger.warning(
                "NLTK stopwords corpus unavailable; continuing without stopword filtering."
            )
            self._stop_words = set()

    async def keyword_search(
        self,
        query: str,
        limit: int,
        project_ids: list[str] | None = None,
        max_candidates: int = 2000,
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters
            max_candidates: Maximum number of candidate documents to fetch from Qdrant before ranking

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Nothing to rank when no results were requested; also avoids issuing
        # zero-limit scroll requests to Qdrant below.
        if limit <= 0:
            return []

        # ✅ Parse query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Keyword search - parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Create filter combining field queries and project IDs
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )

        # Determine how many candidates to fetch per page: min(max_candidates, scaled_limit)
        # Using a scale factor to over-fetch relative to requested limit for better ranking quality
        scale_factor = 5
        scaled_limit = max(limit * scale_factor, limit)
        page_limit = min(max_candidates, scaled_limit)

        # Paginate through Qdrant using scroll until we gather up to max_candidates
        all_points = []
        next_offset = None
        total_fetched = 0
        while total_fetched < max_candidates:
            batch_limit = min(page_limit, max_candidates - total_fetched)
            points, next_offset = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=batch_limit,
                with_payload=True,
                with_vectors=False,
                scroll_filter=query_filter,
                offset=next_offset,
            )

            if not points:
                break

            all_points.extend(points)
            total_fetched += len(points)

            # A falsy next_offset means Qdrant has no more pages.
            if not next_offset:
                break

        self.logger.debug(
            f"Keyword search - fetched {len(all_points)} candidates (requested max {max_candidates}, limit {limit})"
        )

        # Parallel arrays, one entry per candidate point, indexed together
        # with the BM25 scores computed below.
        documents = []
        metadata_list = []
        source_types = []
        titles = []
        urls = []
        document_ids = []
        sources = []
        created_ats = []
        updated_ats = []
        contextual_contents = []

        for point in all_points:
            if point.payload:
                content = point.payload.get("content", "")
                metadata = point.payload.get("metadata", {})
                source_type = point.payload.get("source_type", "unknown")
                # Extract fields directly from Qdrant payload
                title = point.payload.get("title", "")
                url = point.payload.get("url", "")
                document_id = point.payload.get("document_id", "")
                source = point.payload.get("source", "")
                created_at = point.payload.get("created_at", "")
                updated_at = point.payload.get("updated_at", "")
                contextual_content = point.payload.get("contextual_content", "")

                documents.append(content)
                metadata_list.append(metadata)
                source_types.append(source_type)
                titles.append(title)
                urls.append(url)
                document_ids.append(document_id)
                sources.append(source)
                created_ats.append(created_at)
                updated_ats.append(updated_at)
                contextual_contents.append(contextual_content)

        if not documents:
            self.logger.warning("No documents found for keyword search")
            return []

        # Handle filter-only searches (no text query for BM25)
        if self.field_parser.should_use_filter_only(parsed_query):
            self.logger.debug(
                "Filter-only search - assigning equal scores to all results"
            )
            # For filter-only searches, assign equal scores to all results
            scores = np.ones(len(documents))
        else:
            # Use BM25 scoring for text queries, offloaded to a thread
            search_query = parsed_query.text_query if parsed_query.text_query else query
            scores = await asyncio.to_thread(
                self._compute_bm25_scores, documents, search_query
            )

        # Rank by descending score, breaking ties by the original candidate
        # index so equal-scored documents keep their Qdrant scroll order.
        # NOTE: the previous key (scores[i], i) with reverse=True reversed the
        # tie-breaker too, putting tied documents in *reversed* original order.
        top_indices = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:limit]

        results = []
        for idx in top_indices:
            # Zero-score documents do not match the query at all; drop them.
            if scores[idx] > 0:
                result = {
                    "score": float(scores[idx]),
                    "text": documents[idx],
                    "metadata": metadata_list[idx],
                    "source_type": source_types[idx],
                    # Include extracted fields from Qdrant payload
                    "title": titles[idx],
                    "url": urls[idx],
                    "document_id": document_ids[idx],
                    "source": sources[idx],
                    "created_at": created_ats[idx],
                    "updated_at": updated_ats[idx],
                    "contextual_content": contextual_contents[idx],
                }

                results.append(result)

        return results

    # Note: _build_filter method removed - now using FieldQueryParser.create_qdrant_filter()
    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text using NLTK RegexpTokenizer word tokenization.

        Non-string input yields an empty token list. Stopwords are filtered
        (case-insensitively) and the surviving tokens are Snowball-stemmed.

        See: https://www.nltk.org/api/nltk.tokenize.regexp.html
        """
        if not isinstance(text, str):
            return []
        tokenized_text: list[str] = self._tokenizer.tokenize(text)
        return [
            self._stemmer.stem(word)
            for word in tokenized_text
            if word.lower() not in self._stop_words
        ]

    def _compute_bm25_scores(self, documents: list[str], query: str) -> np.ndarray:
        """Compute BM25 scores for documents against the query.

        Tokenizes documents and query with NLTK regex word tokenization.
        Runs synchronously; callers offload it via asyncio.to_thread.
        """
        tokenized_docs = [self._tokenize(doc) for doc in documents]
        # BM25Okapi divides by the average document length; if every candidate
        # tokenizes to nothing (e.g. all-stopword content), that average is 0
        # and construction raises ZeroDivisionError. Return all-zero scores so
        # the caller simply yields no results instead of crashing.
        if not any(tokenized_docs):
            return np.zeros(len(documents))
        bm25 = BM25Okapi(tokenized_docs)
        tokenized_query = self._tokenize(query)
        return bm25.get_scores(tokenized_query)