Coverage for src/qdrant_loader_mcp_server/search/enhanced/kg/traverser.py: 95%
131 statements
1"""
2Knowledge Graph Traversal Engine.
4This module implements advanced graph traversal algorithms for multi-hop search
5and content discovery within knowledge graphs.
6"""
8from __future__ import annotations
10import heapq
11from collections import deque
12from typing import TYPE_CHECKING, Any
14if TYPE_CHECKING:
15 from ...nlp.spacy_analyzer import QueryAnalysis, SpaCyQueryAnalyzer
16 from .models import TraversalResult, TraversalStrategy
17else:
18 QueryAnalysis = Any
19 SpaCyQueryAnalyzer = Any
21from ....utils.logging import LoggingConfig
22from .models import TraversalResult, TraversalStrategy
23from .utils import (
24 ENTITY_SIM_WEIGHT,
25 KEYWORD_SIM_WEIGHT,
26 TOPIC_SIM_WEIGHT,
27 build_reasoning_path,
28 calculate_list_similarity,
29)
31logger = LoggingConfig.get_logger(__name__)
class GraphTraverser:
    """Advanced graph traversal for multi-hop search and content discovery."""

    def __init__(
        self,
        knowledge_graph: Any,  # KnowledgeGraph - avoiding circular import
        spacy_analyzer: SpaCyQueryAnalyzer | None = None,
    ):
        """Initialize the graph traverser."""
        self.graph = knowledge_graph
        # Import SpaCyQueryAnalyzer at runtime to avoid circular import
        if spacy_analyzer is None:
            try:
                from ...nlp.spacy_analyzer import SpaCyQueryAnalyzer
            except ImportError as exc:
                logger.exception(
                    "SpaCyQueryAnalyzer is not available. Ensure optional NLP deps are installed (e.g., 'pip install spacy' and required models)."
                )
                raise ImportError(
                    "SpaCyQueryAnalyzer (and its spacy dependency) is missing. Install optional NLP extras to enable semantic traversal."
                ) from exc
            self.spacy_analyzer = SpaCyQueryAnalyzer()
        else:
            self.spacy_analyzer = spacy_analyzer
        logger.debug("Initialized graph traverser")
    def traverse(
        self,
        start_nodes: list[str],
        query_analysis: QueryAnalysis | None = None,
        strategy: TraversalStrategy = TraversalStrategy.WEIGHTED,
        max_hops: int = 3,
        max_results: int = 20,
        min_weight: float = 0.1,
    ) -> list[TraversalResult]:
        """Traverse the graph to find related content."""
        results = []

        for start_node_id in start_nodes:
            if start_node_id not in self.graph.nodes:
                continue

            # Perform traversal based on strategy
            if strategy == TraversalStrategy.BREADTH_FIRST:
                node_results = self._breadth_first_traversal(
                    start_node_id,
                    query_analysis,
                    max_hops,
                    max_results,
                    min_weight,
                )
            elif strategy == TraversalStrategy.WEIGHTED:
                node_results = self._weighted_traversal(
                    start_node_id,
                    query_analysis,
                    max_hops,
                    max_results,
                    min_weight,
                )
            elif strategy == TraversalStrategy.CENTRALITY:
                node_results = self._centrality_traversal(
                    start_node_id,
                    query_analysis,
                    max_hops,
                    max_results,
                    min_weight,
                )
            elif strategy == TraversalStrategy.SEMANTIC:
                node_results = self._semantic_traversal(
                    start_node_id,
                    query_analysis,
                    max_hops,
                    max_results,
                    min_weight,
                )
            else:
                node_results = self._breadth_first_traversal(
                    start_node_id,
                    query_analysis,
                    max_hops,
                    max_results,
                    min_weight,
                )

            results.extend(node_results)

        # Sort by semantic score and total weight
        results.sort(key=lambda r: (r.semantic_score, r.total_weight), reverse=True)
        return results[:max_results]
    def _breadth_first_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Breadth-first traversal implementation."""
        results = []
        queue = deque(
            [(start_node_id, [], [], 0.0, 0)]
        )  # (node_id, path, edges, weight, hops)
        local_visited = set()

        while queue and len(results) < max_results:
            node_id, path, edges, total_weight, hops = queue.popleft()

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:  # Don't include the starting node
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to queue
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    queue.append(
                        (
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        )
                    )

        return results
    def _weighted_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Weighted traversal prioritizing strong relationships."""
        results = []
        # Priority queue: (negative_weight, node_id, path, edges, weight, hops)
        heap = [(-1.0, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_weight, node_id, path, edges, total_weight, hops = heapq.heappop(heap)

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    new_weight = total_weight + edge.weight
                    heapq.heappush(
                        heap,
                        (
                            -new_weight,  # Negative for max-heap behavior
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            new_weight,
                            hops + 1,
                        ),
                    )

        return results
    def _centrality_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Traversal prioritizing high-centrality nodes."""
        results = []
        # Priority queue: (negative_centrality, node_id, path, edges, weight, hops)
        start_centrality = self.graph.nodes[start_node_id].centrality_score
        heap = [(-start_centrality, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_centrality, node_id, path, edges, total_weight, hops = heapq.heappop(
                heap
            )

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = self._calculate_semantic_score(node_id, query_analysis)
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    neighbor_centrality = self.graph.nodes[neighbor_id].centrality_score
                    heapq.heappush(
                        heap,
                        (
                            -neighbor_centrality,
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        ),
                    )

        return results
    def _semantic_traversal(
        self,
        start_node_id: str,
        query_analysis: QueryAnalysis | None,
        max_hops: int,
        max_results: int,
        min_weight: float,
    ) -> list[TraversalResult]:
        """Traversal prioritizing semantic similarity to query."""
        if not query_analysis:
            return self._breadth_first_traversal(
                start_node_id,
                query_analysis,
                max_hops,
                max_results,
                min_weight,
            )

        results = []
        # Priority queue: (negative_semantic_score, node_id, path, edges, weight, hops)
        start_score = self._calculate_semantic_score(start_node_id, query_analysis)
        heap = [(-start_score, start_node_id, [], [], 0.0, 0)]
        local_visited = set()

        while heap and len(results) < max_results:
            neg_score, node_id, path, edges, total_weight, hops = heapq.heappop(heap)

            if node_id in local_visited or hops > max_hops:
                continue

            local_visited.add(node_id)

            # Create traversal result
            if node_id != start_node_id:
                semantic_score = -neg_score  # Convert back from negative
                reasoning_path = build_reasoning_path(edges, self.graph.nodes)

                result = TraversalResult(
                    path=path + [node_id],
                    nodes=[self.graph.nodes[nid] for nid in path + [node_id]],
                    edges=edges,
                    total_weight=total_weight,
                    semantic_score=semantic_score,
                    hop_count=hops,
                    reasoning_path=reasoning_path,
                )
                results.append(result)

            # Add neighbors to heap
            neighbors = self.graph.get_neighbors(node_id)
            for neighbor_id, edge in neighbors:
                if neighbor_id not in local_visited and edge.weight >= min_weight:
                    neighbor_score = self._calculate_semantic_score(
                        neighbor_id, query_analysis
                    )
                    heapq.heappush(
                        heap,
                        (
                            -neighbor_score,
                            neighbor_id,
                            path + [node_id],
                            edges + [edge],
                            total_weight + edge.weight,
                            hops + 1,
                        ),
                    )

        return results
    def _calculate_semantic_score(
        self, node_id: str, query_analysis: QueryAnalysis | None
    ) -> float:
        """Calculate semantic similarity between node and query."""
        if not query_analysis:
            return 0.0

        node = self.graph.nodes[node_id]

        # Calculate similarity based on entities, topics, and keywords
        entity_similarity = calculate_list_similarity(
            query_analysis.entities, [(e, "") for e in node.entities]
        )

        topic_similarity = calculate_list_similarity(
            [(t, "") for t in query_analysis.main_concepts],
            [(t, "") for t in node.topics],
        )

        keyword_similarity = calculate_list_similarity(
            [(k, "") for k in query_analysis.semantic_keywords],
            [(k, "") for k in node.keywords],
        )

        # Weighted combination
        total_score = (
            entity_similarity * ENTITY_SIM_WEIGHT
            + topic_similarity * TOPIC_SIM_WEIGHT
            + keyword_similarity * KEYWORD_SIM_WEIGHT
        )

        return total_score
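

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module above): a minimal, self-contained
# version of the weighted best-first expansion that _weighted_traversal
# implements, using a plain adjacency dict instead of the KnowledgeGraph
# interface. The node names and edge weights below are made up for
# demonstration; only the expansion order and the min_weight/max_hops pruning
# mirror the method above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import heapq as _heapq

    # Toy graph: node -> list of (neighbor, edge_weight).
    toy_graph = {
        "docs": [("api", 0.9), ("faq", 0.4)],
        "api": [("auth", 0.8), ("faq", 0.2)],
        "auth": [],
        "faq": [],
    }

    def toy_weighted_traversal(start, max_hops=3, min_weight=0.1):
        """Expand neighbors in order of accumulated path weight (max-heap)."""
        found = []
        # Heap entries: (-accumulated_weight, node, path, hops).
        heap = [(-0.0, start, [start], 0)]
        visited = set()
        while heap:
            neg_weight, node, path, hops = _heapq.heappop(heap)
            if node in visited or hops > max_hops:
                continue
            visited.add(node)
            if node != start:  # Don't report the starting node itself
                found.append((path, -neg_weight, hops))
            for neighbor, weight in toy_graph.get(node, []):
                if neighbor not in visited and weight >= min_weight:
                    _heapq.heappush(
                        heap,
                        (neg_weight - weight, neighbor, path + [neighbor], hops + 1),
                    )
        return found

    for path, weight, hops in toy_weighted_traversal("docs"):
        print(f"{' -> '.join(path)}  weight={weight:.2f}  hops={hops}")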