Coverage for src/qdrant_loader_mcp_server/search/components/keyword_search_service.py: 93% (103 statements)
1"""Keyword search service for hybrid search."""
3from __future__ import annotations
5import asyncio
6from typing import TYPE_CHECKING, Any
8import nltk
9import numpy as np
10from nltk.stem import SnowballStemmer
11from nltk.tokenize import RegexpTokenizer
12from rank_bm25 import BM25Okapi
14if TYPE_CHECKING:
15 from qdrant_client import AsyncQdrantClient
17from ...utils.logging import LoggingConfig
18from .field_query_parser import FieldQueryParser


class KeywordSearchService:
    """Handles keyword search operations using BM25."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
    ):
        """Initialize the keyword search service.

        Args:
            qdrant_client: Qdrant client instance
            collection_name: Name of the Qdrant collection
        """
        self.qdrant_client = qdrant_client
        self.collection_name = collection_name
        self.field_parser = FieldQueryParser()
        self.logger = LoggingConfig.get_logger(__name__)
        self._stemmer = SnowballStemmer(language="english")

        from nltk.corpus import stopwords

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            downloaded = nltk.download("stopwords", quiet=True)
            if not downloaded:
                self.logger.warning(
                    "NLTK stopwords download failed; continuing without stopword filtering."
                )

        try:
            self._stop_words = set(stopwords.words("english"))
        except LookupError:
            self.logger.warning(
                "NLTK stopwords corpus unavailable; continuing without stopword filtering."
            )
            self._stop_words = set()

    async def keyword_search(
        self,
        query: str,
        limit: int,
        project_ids: list[str] | None = None,
        max_candidates: int = 2000,
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters
            max_candidates: Maximum number of candidate documents to fetch
                from Qdrant before ranking

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Parse query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Keyword search - parsed query: {len(parsed_query.field_queries)} field queries, "
            f"text: '{parsed_query.text_query}'"
        )
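        # Illustrative example (the exact field syntax is whatever
        # FieldQueryParser accepts; "source_type:git" here is hypothetical):
        # a query like "source_type:git authentication flow" would yield one
        # field query (source_type == "git") that becomes a Qdrant filter,
        # plus the free text "authentication flow" that is scored with BM25.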

        # Create filter combining field queries and project IDs
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )

        # Determine how many candidates to fetch per scroll page:
        # min(max_candidates, scaled_limit). We over-fetch relative to the
        # requested limit (by scale_factor) to improve ranking quality.
        scale_factor = 5
        scaled_limit = max(limit * scale_factor, limit)
        page_limit = min(max_candidates, scaled_limit)
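        # Example: limit=10 gives scaled_limit=50, so with the default
        # max_candidates=2000 each scroll page requests 50 points.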

        # Paginate through Qdrant using scroll until we gather up to max_candidates
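        # AsyncQdrantClient.scroll returns a (points, next_page_offset) tuple;
        # a None offset means the collection (under this filter) is exhausted.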
        all_points = []
        next_offset = None
        total_fetched = 0
        while total_fetched < max_candidates:
            batch_limit = min(page_limit, max_candidates - total_fetched)
            points, next_offset = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=batch_limit,
                with_payload=True,
                with_vectors=False,
                scroll_filter=query_filter,
                offset=next_offset,
            )

            if not points:
                break

            all_points.extend(points)
            total_fetched += len(points)

            if not next_offset:
                break

        self.logger.debug(
            f"Keyword search - fetched {len(all_points)} candidates "
            f"(requested max {max_candidates}, limit {limit})"
        )

        documents = []
        metadata_list = []
        source_types = []
        titles = []
        urls = []
        document_ids = []
        sources = []
        created_ats = []
        updated_ats = []

        for point in all_points:
            if point.payload:
                content = point.payload.get("content", "")
                metadata = point.payload.get("metadata", {})
                source_type = point.payload.get("source_type", "unknown")
                # Extract fields directly from Qdrant payload
                title = point.payload.get("title", "")
                url = point.payload.get("url", "")
                document_id = point.payload.get("document_id", "")
                source = point.payload.get("source", "")
                created_at = point.payload.get("created_at", "")
                updated_at = point.payload.get("updated_at", "")

                documents.append(content)
                metadata_list.append(metadata)
                source_types.append(source_type)
                titles.append(title)
                urls.append(url)
                document_ids.append(document_id)
                sources.append(source)
                created_ats.append(created_at)
                updated_ats.append(updated_at)

        if not documents:
            self.logger.warning("No documents found for keyword search")
            return []

        # Handle filter-only searches (no text query for BM25)
        if self.field_parser.should_use_filter_only(parsed_query):
            self.logger.debug(
                "Filter-only search - assigning equal scores to all results"
            )
            # For filter-only searches, assign equal scores to all results
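            # (e.g. a pure field query such as "source_type:confluence" with
            # no free text; the field name here is illustrative)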
            scores = np.ones(len(documents))
        else:
            # Use BM25 scoring for text queries, offloaded to a thread
            search_query = (
                parsed_query.text_query if parsed_query.text_query else query
            )
            scores = await asyncio.to_thread(
                self._compute_bm25_scores, documents, search_query
            )

        # Rank by descending score; the ascending-index tiebreaker keeps the
        # original candidate order among equal scores. (With reverse=True and
        # the index in the key, ties would come out in reversed order.)
        top_indices = np.array(
            sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:limit]
        )
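        # e.g. scores [0.2, 0.9, 0.9] -> top_indices [1, 2, 0]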

        results = []
        for idx in top_indices:
            if scores[idx] > 0:
                result = {
                    "score": float(scores[idx]),
                    "text": documents[idx],
                    "metadata": metadata_list[idx],
                    "source_type": source_types[idx],
                    # Include extracted fields from Qdrant payload
                    "title": titles[idx],
                    "url": urls[idx],
                    "document_id": document_ids[idx],
                    "source": sources[idx],
                    "created_at": created_ats[idx],
                    "updated_at": updated_ats[idx],
                }

                results.append(result)

        return results

    # Note: _build_filter method removed - now using FieldQueryParser.create_qdrant_filter()
    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text using NLTK RegexpTokenizer word tokenization.

        See: https://www.nltk.org/api/nltk.tokenize.regexp.html
        """
        if not isinstance(text, str):
            return []
        tokenized_text: list[str] = RegexpTokenizer(r"\b\w+\b").tokenize(text)
        return [
            self._stemmer.stem(word)
            for word in tokenized_text
            if word.lower() not in self._stop_words
        ]

    def _compute_bm25_scores(self, documents: list[str], query: str) -> np.ndarray:
        """Compute BM25 scores for documents against the query.

        Tokenizes documents and query with NLTK regex word tokenization.
        """
        tokenized_docs = [self._tokenize(doc) for doc in documents]
        bm25 = BM25Okapi(tokenized_docs)
        tokenized_query = self._tokenize(query)
        return bm25.get_scores(tokenized_query)
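

# --- Usage sketch (illustrative only; the URL and collection name below are
# placeholders, not part of this module) ---
#
#     import asyncio
#     from qdrant_client import AsyncQdrantClient
#
#     async def main() -> None:
#         client = AsyncQdrantClient(url="http://localhost:6333")
#         service = KeywordSearchService(client, collection_name="documents")
#         results = await service.keyword_search("authentication flow", limit=5)
#         for r in results:
#             print(f"{r['score']:.3f}  {r['title']}")
#
#     asyncio.run(main())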