Coverage for src/qdrant_loader_mcp_server/search/enhanced/kg/traverser.py: 95%

131 statements  

coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

1""" 

2Knowledge Graph Traversal Engine. 

3 

4This module implements advanced graph traversal algorithms for multi-hop search 

5and content discovery within knowledge graphs. 

6""" 

7 

8from __future__ import annotations 

9 

10import heapq 

11from collections import deque 

12from typing import TYPE_CHECKING, Any 

13 

14if TYPE_CHECKING: 

15 from ...nlp.spacy_analyzer import QueryAnalysis, SpaCyQueryAnalyzer 

16 from .models import TraversalResult, TraversalStrategy 

17else: 

18 QueryAnalysis = Any 

19 SpaCyQueryAnalyzer = Any 

20 

21from ....utils.logging import LoggingConfig 

22from .models import TraversalResult, TraversalStrategy 

23from .utils import ( 

24 ENTITY_SIM_WEIGHT, 

25 KEYWORD_SIM_WEIGHT, 

26 TOPIC_SIM_WEIGHT, 

27 build_reasoning_path, 

28 calculate_list_similarity, 

29) 

30 

31logger = LoggingConfig.get_logger(__name__) 

32 

33 

34class GraphTraverser: 

35 """Advanced graph traversal for multi-hop search and content discovery.""" 

36 

37 def __init__( 

38 self, 

39 knowledge_graph: Any, # KnowledgeGraph - avoiding circular import 

40 spacy_analyzer: SpaCyQueryAnalyzer | None = None, 

41 ): 

42 """Initialize the graph traverser.""" 

43 self.graph = knowledge_graph 

44 # Import SpaCyQueryAnalyzer at runtime to avoid circular import 

45 if spacy_analyzer is None: 

46 try: 

47 from ...nlp.spacy_analyzer import SpaCyQueryAnalyzer 

48 except ImportError as exc: 

49 logger.exception( 

50 "SpaCyQueryAnalyzer is not available. Ensure optional NLP deps are installed (e.g., 'pip install spacy' and required models)." 

51 ) 

52 raise ImportError( 

53 "SpaCyQueryAnalyzer (and its spacy dependency) is missing. Install optional NLP extras to enable semantic traversal." 

54 ) from exc 

55 self.spacy_analyzer = SpaCyQueryAnalyzer() 

56 else: 

57 self.spacy_analyzer = spacy_analyzer 

58 logger.debug("Initialized graph traverser") 

59 

60 def traverse( 

61 self, 

62 start_nodes: list[str], 

63 query_analysis: QueryAnalysis | None = None, 

64 strategy: TraversalStrategy = TraversalStrategy.WEIGHTED, 

65 max_hops: int = 3, 

66 max_results: int = 20, 

67 min_weight: float = 0.1, 

68 ) -> list[TraversalResult]: 

69 """Traverse the graph to find related content.""" 

70 

71 results = [] 

72 

73 for start_node_id in start_nodes: 

74 if start_node_id not in self.graph.nodes: 

75 continue 

76 

77 # Perform traversal based on strategy 

78 if strategy == TraversalStrategy.BREADTH_FIRST: 

79 node_results = self._breadth_first_traversal( 

80 start_node_id, 

81 query_analysis, 

82 max_hops, 

83 max_results, 

84 min_weight, 

85 ) 

86 elif strategy == TraversalStrategy.WEIGHTED: 

87 node_results = self._weighted_traversal( 

88 start_node_id, 

89 query_analysis, 

90 max_hops, 

91 max_results, 

92 min_weight, 

93 ) 

94 elif strategy == TraversalStrategy.CENTRALITY: 

95 node_results = self._centrality_traversal( 

96 start_node_id, 

97 query_analysis, 

98 max_hops, 

99 max_results, 

100 min_weight, 

101 ) 

102 elif strategy == TraversalStrategy.SEMANTIC: 

103 node_results = self._semantic_traversal( 

104 start_node_id, 

105 query_analysis, 

106 max_hops, 

107 max_results, 

108 min_weight, 

109 ) 

110 else: 

111 node_results = self._breadth_first_traversal( 

112 start_node_id, 

113 query_analysis, 

114 max_hops, 

115 max_results, 

116 min_weight, 

117 ) 

118 

119 results.extend(node_results) 

120 

121 # Sort by semantic score and total weight 

122 results.sort(key=lambda r: (r.semantic_score, r.total_weight), reverse=True) 

123 return results[:max_results] 
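
    # Strategy cheat-sheet (descriptive only; mirrors the dispatch above):
    #   BREADTH_FIRST - explore by hop distance (FIFO queue)
    #   WEIGHTED      - pop the highest cumulative edge weight first (max-heap)
    #   CENTRALITY    - pop the highest node centrality_score first
    #   SEMANTIC      - pop the highest query similarity first (BFS when no query)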

    def _breadth_first_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Breadth-first traversal implementation."""

        results = []
        queue = deque(
            [(start_node_id, [], [], 0.0, 0)]
        )  # (node_id, path, edges, weight, hops)
        local_visited = set()

        while queue and len(results) < max_results:
            node_id, path, edges, total_weight, hops = queue.popleft()

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:  # Don't include the starting node
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to queue
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    queue.append(
                        (
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        )
                    )

        return results

    def _weighted_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Weighted traversal prioritizing strong relationships."""

        results = []
        # Priority queue: (negative_weight, node_id, path, edges, weight, hops)
        heap = [(-1.0, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_weight, node_id, path, edges, total_weight, hops = heapq.heappop(heap)

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    new_weight = total_weight + edge.weight
                    heapq.heappush(
                        heap,
                        (
                            -new_weight,  # Negative for max-heap behavior
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            new_weight,
                            hops + 1,
                        ),
                    )

        return results
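
    # Note on the ordering above: heapq is a min-heap, so cumulative path weights
    # are pushed negated; e.g. weights 0.9 and 0.4 are stored as -0.9 and -0.4,
    # and -0.9 pops first, yielding strongest-path-first (max-heap) behavior.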

    def _centrality_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Traversal prioritizing high-centrality nodes."""

        results = []
        # Priority queue: (negative_centrality, node_id, path, edges, weight, hops)
        start_centrality = self.graph.nodes[start_node_id].centrality_score
        heap = [(-start_centrality, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_centrality, node_id, path, edges, total_weight, hops = heapq.heappop(
                heap
            )

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    neighbor_centrality = self.graph.nodes[neighbor_id].centrality_score
                    heapq.heappush(
                        heap,
                        (
                            -neighbor_centrality,
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        ),
                    )

        return results

    def _semantic_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Traversal prioritizing semantic similarity to the query."""

        if not query_analysis:
            return self._breadth_first_traversal(
                start_node_id,
                query_analysis,
                max_hops,
                max_results,
                min_weight,
            )

        results = []
        # Priority queue: (negative_semantic_score, node_id, path, edges, weight, hops)
        start_score = self._calculate_semantic_score(start_node_id, query_analysis)
        heap = [(-start_score, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_score, node_id, path, edges, total_weight, hops = heapq.heappop(heap)

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = -neg_score  # Convert back from negative
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    neighbor_score = self._calculate_semantic_score(
                        neighbor_id, query_analysis
                    )
                    heapq.heappush(
                        heap,
                        (
                            -neighbor_score,
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        ),
                    )

        return results

    def _calculate_semantic_score(
        self, node_id: str, query_analysis: QueryAnalysis | None
    ) -> float:
        """Calculate semantic similarity between a node and the query."""
        if not query_analysis:
            return 0.0

        node = self.graph.nodes[node_id]

        # Calculate similarity based on entities, topics, and keywords
        entity_similarity = calculate_list_similarity(
            query_analysis.entities, [(e, "") for e in node.entities]
        )

        topic_similarity = calculate_list_similarity(
            [(t, "") for t in query_analysis.main_concepts],
            [(t, "") for t in node.topics],
        )

        keyword_similarity = calculate_list_similarity(
            [(k, "") for k in query_analysis.semantic_keywords],
            [(k, "") for k in node.keywords],
        )

        # Weighted combination of the three similarity signals
        total_score = (
            entity_similarity * ENTITY_SIM_WEIGHT
            + topic_similarity * TOPIC_SIM_WEIGHT
            + keyword_similarity * KEYWORD_SIM_WEIGHT
        )

        return total_score
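
A minimal usage sketch (not part of the file above): it assumes an already-populated
KnowledgeGraph exposing .nodes (a mapping from node id to node objects carrying
entities, topics, keywords, and centrality_score) and .get_neighbors(node_id)
yielding (neighbor_id, edge) pairs with an edge.weight; build_knowledge_graph and
the node id are hypothetical stand-ins.

    from qdrant_loader_mcp_server.search.enhanced.kg.models import TraversalStrategy
    from qdrant_loader_mcp_server.search.enhanced.kg.traverser import GraphTraverser

    graph = build_knowledge_graph()  # hypothetical helper returning a KnowledgeGraph
    traverser = GraphTraverser(graph)  # builds a SpaCyQueryAnalyzer when none is passed

    results = traverser.traverse(
        start_nodes=["doc-oauth-setup"],        # hypothetical node id
        strategy=TraversalStrategy.CENTRALITY,  # hub-first exploration
        max_hops=2,
        max_results=10,
    )
    for r in results:
        print(r.hop_count, round(r.total_weight, 2), r.reasoning_path)

The semantic score itself is a weighted sum of entity, topic, and keyword
similarities; the ENTITY_SIM_WEIGHT, TOPIC_SIM_WEIGHT, and KEYWORD_SIM_WEIGHT
constants are defined in .utils and their values are not shown in this report.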