Coverage for src / qdrant_loader_mcp_server / search / components / keyword_search_service.py: 93%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-10 09:41 +0000

1"""Keyword search service for hybrid search.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6from typing import TYPE_CHECKING, Any 

7 

8import nltk 

9import numpy as np 

10from nltk.stem import SnowballStemmer 

11from nltk.tokenize import RegexpTokenizer 

12from rank_bm25 import BM25Okapi 

13 

14if TYPE_CHECKING: 

15 from qdrant_client import AsyncQdrantClient 

16 

17from ...utils.logging import LoggingConfig 

18from .field_query_parser import FieldQueryParser 

19 

20 

class KeywordSearchService:
    """Keyword (lexical) search over a Qdrant collection, ranked with BM25.

    Candidate documents are fetched from Qdrant via scroll pagination and
    ranked locally with rank_bm25's Okapi implementation over stemmed,
    stopword-filtered tokens.
    """

    # Over-fetch factor relative to the requested limit: fetching more
    # candidates than `limit` improves local BM25 ranking quality.
    _SCALE_FACTOR = 5

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
    ):
        """Initialize the keyword search service.

        Args:
            qdrant_client: Qdrant client instance
            collection_name: Name of the Qdrant collection
        """
        self.qdrant_client = qdrant_client
        self.collection_name = collection_name
        self.field_parser = FieldQueryParser()
        self.logger = LoggingConfig.get_logger(__name__)
        self._stemmer = SnowballStemmer(language="english")
        # Build the word tokenizer once; it is reused for every document
        # and query instead of being reconstructed per tokenize call.
        self._tokenizer = RegexpTokenizer(r"\b\w+\b")

        from nltk.corpus import stopwords

        # Ensure the stopwords corpus is available; a failed download is
        # non-fatal -- search degrades to no stopword filtering.
        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            downloaded = nltk.download("stopwords", quiet=True)
            if not downloaded:
                self.logger.warning(
                    "NLTK stopwords download failed; continuing without stopword filtering."
                )

        try:
            self._stop_words = set(stopwords.words("english"))
        except LookupError:
            self.logger.warning(
                "NLTK stopwords corpus unavailable; continuing without stopword filtering."
            )
            self._stop_words = set()

    async def keyword_search(
        self,
        query: str,
        limit: int,
        project_ids: list[str] | None = None,
        max_candidates: int = 2000,
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25.

        Args:
            query: Search query (may contain field-specific filters)
            limit: Maximum number of results
            project_ids: Optional project ID filters
            max_candidates: Maximum number of candidate documents to fetch
                from Qdrant before ranking

        Returns:
            List of search results with scores, text, metadata, and source_type
        """
        # Parse query for field-specific filters (e.g. "field:value" terms).
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Keyword search - parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Create filter combining field queries and project IDs.
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )

        all_points = await self._fetch_candidates(query_filter, limit, max_candidates)
        self.logger.debug(
            f"Keyword search - fetched {len(all_points)} candidates (requested max {max_candidates}, limit {limit})"
        )

        # Flatten payloads into flat result records; points without a
        # payload carry nothing to rank or return and are skipped.
        records = [
            self._payload_to_record(point.payload)
            for point in all_points
            if point.payload
        ]
        if not records:
            self.logger.warning("No documents found for keyword search")
            return []

        documents = [record["text"] for record in records]

        if self.field_parser.should_use_filter_only(parsed_query):
            # Filter-only search: there is no text to rank on, so every
            # matching document receives the same score.
            self.logger.debug(
                "Filter-only search - assigning equal scores to all results"
            )
            scores = np.ones(len(documents))
        else:
            # BM25 scoring is CPU-bound; offload it to a thread so the
            # event loop stays responsive.
            search_query = parsed_query.text_query or query
            scores = await asyncio.to_thread(
                self._compute_bm25_scores, documents, search_query
            )

        results = []
        for idx in self._rank_top_indices(scores, limit):
            # Zero-score documents did not match the query at all.
            if scores[idx] > 0:
                results.append({"score": float(scores[idx]), **records[idx]})
        return results

    async def _fetch_candidates(
        self, query_filter: Any, limit: int, max_candidates: int
    ) -> list[Any]:
        """Scroll through Qdrant, gathering up to `max_candidates` points.

        Args:
            query_filter: Qdrant filter built from field queries/project IDs
            limit: Caller-requested result limit (drives page sizing)
            max_candidates: Hard cap on total points fetched

        Returns:
            The fetched points (payloads included, vectors excluded).
        """
        # Page size: over-fetch relative to the requested limit for better
        # ranking quality, but never more than max_candidates at once.
        page_limit = min(max_candidates, max(limit * self._SCALE_FACTOR, limit))

        all_points: list[Any] = []
        next_offset = None
        total_fetched = 0
        while total_fetched < max_candidates:
            batch_limit = min(page_limit, max_candidates - total_fetched)
            points, next_offset = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=batch_limit,
                with_payload=True,
                with_vectors=False,
                scroll_filter=query_filter,
                offset=next_offset,
            )

            if not points:
                break

            all_points.extend(points)
            total_fetched += len(points)

            # A falsy offset means Qdrant has no further pages.
            if not next_offset:
                break

        return all_points

    @staticmethod
    def _payload_to_record(payload: dict[str, Any]) -> dict[str, Any]:
        """Map a Qdrant point payload onto the flat result-record shape.

        Missing payload fields fall back to "" (or {} for metadata and
        "unknown" for source_type), matching the result contract.
        """
        return {
            "text": payload.get("content", ""),
            "metadata": payload.get("metadata", {}),
            "source_type": payload.get("source_type", "unknown"),
            "title": payload.get("title", ""),
            "url": payload.get("url", ""),
            "document_id": payload.get("document_id", ""),
            "source": payload.get("source", ""),
            "created_at": payload.get("created_at", ""),
            "updated_at": payload.get("updated_at", ""),
            "contextual_content": payload.get("contextual_content", ""),
        }

    @staticmethod
    def _rank_top_indices(scores, limit: int) -> np.ndarray:
        """Return the indices of the top-`limit` scores.

        Ordered by score descending with ties broken by original index
        ascending, so ranking is stable. (The previous implementation
        applied ``reverse=True`` over a ``(score, index)`` key, which
        inverted index order among equal scores.)
        """
        order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
        return np.array(order[:limit])

    # Note: _build_filter method removed - now using FieldQueryParser.create_qdrant_filter()
    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text using NLTK RegexpTokenizer word tokenization.

        Non-string input yields an empty token list. Stopwords are dropped
        (case-insensitive match) and the remaining tokens are stemmed.

        See: https://www.nltk.org/api/nltk.tokenize.regexp.html
        """
        if not isinstance(text, str):
            return []
        return [
            self._stemmer.stem(word)
            for word in self._tokenizer.tokenize(text)
            if word.lower() not in self._stop_words
        ]

    def _compute_bm25_scores(self, documents: list[str], query: str) -> np.ndarray:
        """Compute BM25 scores for documents against the query.

        Tokenizes documents and query with NLTK regex word tokenization.
        """
        bm25 = BM25Okapi([self._tokenize(doc) for doc in documents])
        return bm25.get_scores(self._tokenize(query))