Coverage for src/qdrant_loader_mcp_server/search/components/keyword_search_service.py: 93%

"""Keyword search service for hybrid search."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any

import nltk
import numpy as np
from nltk.stem import SnowballStemmer
from nltk.tokenize import RegexpTokenizer
from rank_bm25 import BM25Okapi

if TYPE_CHECKING:
    from qdrant_client import AsyncQdrantClient

from ...utils.logging import LoggingConfig
from .field_query_parser import FieldQueryParser


class KeywordSearchService:
    """Handles keyword search operations using BM25."""

    def __init__(
        self,
        qdrant_client: AsyncQdrantClient,
        collection_name: str,
    ):
        """Initialize the keyword search service.

        Args:
            qdrant_client: Qdrant client instance
            collection_name: Name of the Qdrant collection
        """
        self.qdrant_client = qdrant_client
        self.collection_name = collection_name
        self.field_parser = FieldQueryParser()
        self.logger = LoggingConfig.get_logger(__name__)
        self._stemmer = SnowballStemmer(language="english")

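        # Stopword setup: make sure the NLTK corpus is available, downloading it
        # on first use if needed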
        from nltk.corpus import stopwords

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            downloaded = nltk.download("stopwords", quiet=True)
            if not downloaded:
                self.logger.warning(
                    "NLTK stopwords download failed; continuing without stopword filtering."
                )

        try:
            self._stop_words = set(stopwords.words("english"))
        except LookupError:
            self.logger.warning(
                "NLTK stopwords corpus unavailable; continuing without stopword filtering."
            )
            self._stop_words = set()

    async def keyword_search(
        self,
        query: str,
        limit: int,
        project_ids: list[str] | None = None,
        max_candidates: int = 2000,
    ) -> list[dict[str, Any]]:
        """Perform keyword search using BM25.

        Args:
            query: Search query
            limit: Maximum number of results
            project_ids: Optional project ID filters
            max_candidates: Maximum number of candidate documents to fetch from Qdrant before ranking

        Returns:
            List of search results with score, text, metadata, source_type, and
            payload fields (title, url, document_id, source, timestamps)
        """
        # Parse the query for field-specific filters
        parsed_query = self.field_parser.parse_query(query)
        self.logger.debug(
            f"Keyword search - parsed query: {len(parsed_query.field_queries)} field queries, text: '{parsed_query.text_query}'"
        )

        # Create a filter combining field queries and project IDs
        query_filter = self.field_parser.create_qdrant_filter(
            parsed_query.field_queries, project_ids
        )

        # Determine how many candidates to fetch per scroll page: over-fetch
        # relative to the requested limit so BM25 has a richer pool to rank
        scale_factor = 5
        scaled_limit = max(limit * scale_factor, limit)
        page_limit = min(max_candidates, scaled_limit)
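        # e.g. limit=10 -> scaled_limit=50, so page_limit=min(2000, 50)=50
        # with the default max_candidates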

        # Paginate through Qdrant using scroll until we gather up to max_candidates
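        # (scroll returns a (points, next_page_offset) tuple; a None offset
        # signals that the collection is exhausted)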
        all_points = []
        next_offset = None
        total_fetched = 0
        while total_fetched < max_candidates:
            batch_limit = min(page_limit, max_candidates - total_fetched)
            points, next_offset = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                limit=batch_limit,
                with_payload=True,
                with_vectors=False,
                scroll_filter=query_filter,
                offset=next_offset,
            )

            if not points:
                break

            all_points.extend(points)
            total_fetched += len(points)

            if not next_offset:
                break

        self.logger.debug(
            f"Keyword search - fetched {len(all_points)} candidates (requested max {max_candidates}, limit {limit})"
        )

        documents = []
        metadata_list = []
        source_types = []
        titles = []
        urls = []
        document_ids = []
        sources = []
        created_ats = []
        updated_ats = []

        for point in all_points:
            if point.payload:
                content = point.payload.get("content", "")
                metadata = point.payload.get("metadata", {})
                source_type = point.payload.get("source_type", "unknown")
                # Extract fields directly from the Qdrant payload
                title = point.payload.get("title", "")
                url = point.payload.get("url", "")
                document_id = point.payload.get("document_id", "")
                source = point.payload.get("source", "")
                created_at = point.payload.get("created_at", "")
                updated_at = point.payload.get("updated_at", "")

                documents.append(content)
                metadata_list.append(metadata)
                source_types.append(source_type)
                titles.append(title)
                urls.append(url)
                document_ids.append(document_id)
                sources.append(source)
                created_ats.append(created_at)
                updated_ats.append(updated_at)

        if not documents:
            self.logger.warning("No documents found for keyword search")
            return []

        # Handle filter-only searches (no text query for BM25): give all matches equal scores
        if self.field_parser.should_use_filter_only(parsed_query):
            self.logger.debug(
                "Filter-only search - assigning equal scores to all results"
            )
            scores = np.ones(len(documents))
        else:
            # Use BM25 scoring for text queries, offloaded to a thread
            search_query = parsed_query.text_query or query
            scores = await asyncio.to_thread(
                self._compute_bm25_scores, documents, search_query
            )

        # Rank by descending score, breaking ties by original index so the
        # ordering among equal scores stays stable
        top_indices = np.array(
            sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:limit]
        )

        results = []
        for idx in top_indices:
            if scores[idx] > 0:
                result = {
                    "score": float(scores[idx]),
                    "text": documents[idx],
                    "metadata": metadata_list[idx],
                    "source_type": source_types[idx],
                    # Include extracted fields from the Qdrant payload
                    "title": titles[idx],
                    "url": urls[idx],
                    "document_id": document_ids[idx],
                    "source": sources[idx],
                    "created_at": created_ats[idx],
                    "updated_at": updated_ats[idx],
                }

                results.append(result)

        return results

    # Note: the _build_filter method was removed; filtering now goes through
    # FieldQueryParser.create_qdrant_filter()
    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text into word tokens using NLTK's RegexpTokenizer.

        See: https://www.nltk.org/api/nltk.tokenize.regexp.html
        """
        if not isinstance(text, str):
            return []
        tokenized_text: list[str] = RegexpTokenizer(r"\b\w+\b").tokenize(text)
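        # Stem each token and drop English stopwords so queries and documents
        # share a normalized vocabulary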
        return [
            self._stemmer.stem(word)
            for word in tokenized_text
            if word.lower() not in self._stop_words
        ]

    def _compute_bm25_scores(self, documents: list[str], query: str) -> np.ndarray:
        """Compute BM25 scores for documents against the query.

        Tokenizes the documents and the query with NLTK regex word tokenization.
        """
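        # Note: BM25Okapi builds its term statistics from scratch on every call;
        # that is acceptable here because the corpus is capped at max_candidates
        # documents per query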
        tokenized_docs = [self._tokenize(doc) for doc in documents]
        bm25 = BM25Okapi(tokenized_docs)
        tokenized_query = self._tokenize(query)
        return bm25.get_scores(tokenized_query)
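

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of driving the service, assuming a Qdrant instance at
# localhost:6333 and a collection named "documents" (both hypothetical), run
# from within the package so the relative imports resolve:
#
#     from qdrant_client import AsyncQdrantClient
#
#     async def main() -> None:
#         client = AsyncQdrantClient(url="http://localhost:6333")
#         service = KeywordSearchService(client, collection_name="documents")
#         results = await service.keyword_search("installation guide", limit=5)
#         for result in results:
#             print(f"{result['score']:.3f}  {result['title']}")
#
#     asyncio.run(main())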