Coverage for src / qdrant_loader_mcp_server / search / enhanced / kg / graph.py: 68%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 04:51 +0000

1""" 

2Core Knowledge Graph Implementation. 

3 

4This module implements the core knowledge graph using NetworkX with optimized 

5node/edge management, indexing, and centrality calculations. 

6""" 

7 

8from __future__ import annotations 

9 

10from collections import defaultdict 

11from typing import Any 

12 

13# Note: networkx is imported at module level because: 

14# 1. This module (kg/graph) is only imported when KnowledgeGraph is needed 

15# 2. Tests need to be able to patch nx functions 

16# The lazy loading happens at a higher level 

17import networkx as nx 

18 

19from ....utils.logging import LoggingConfig 

20from .models import GraphEdge, GraphNode, NodeType, RelationshipType 

21 

22logger = LoggingConfig.get_logger(__name__) 

23 

24 

class KnowledgeGraph:
    """Core knowledge graph implementation using NetworkX.

    Wraps a ``networkx.MultiDiGraph`` and keeps parallel lookup structures
    next to it:

    * ``nodes`` maps a node id to its ``GraphNode`` object.
    * ``edges`` maps ``(source_id, target_id, relationship value)`` to its
      ``GraphEdge`` object.
    * ``node_type_index``, ``entity_index`` and ``topic_index`` map a node
      type, a lower-cased entity and a lower-cased topic to the set of node
      ids carrying it.
    """

    def __init__(self):
        """Initialize the knowledge graph."""
        self.graph = nx.MultiDiGraph()  # Allow multiple edges between nodes
        self.nodes: dict[str, GraphNode] = {}
        # Keyed by (source, target, relationship)
        self.edges: dict[tuple[str, str, str], GraphEdge] = {}
        self.node_type_index: dict[NodeType, set[str]] = defaultdict(set)
        self.entity_index: dict[str, set[str]] = defaultdict(set)  # entity -> node_ids
        self.topic_index: dict[str, set[str]] = defaultdict(set)  # topic -> node_ids

        logger.info("Initialized empty knowledge graph")

    def _unindex_node(self, node: GraphNode) -> None:
        """Remove ``node.id`` from the type, entity and topic indices.

        Cleanup is best-effort: an index key that cannot be derived from
        *node* is logged at debug level and skipped. Lookups use ``.get`` so
        that cleaning up never inserts new, empty index entries.
        """
        node_id = node.id

        try:
            type_members = self.node_type_index.get(node.node_type)
        except (AttributeError, TypeError) as exc:
            logger.debug("Could not un-index type of node %s: %s", node_id, exc)
            type_members = None
        if type_members is not None:
            type_members.discard(node_id)

        label_sources = (
            (self.entity_index, getattr(node, "entities", None) or ()),
            (self.topic_index, getattr(node, "topics", None) or ()),
        )
        for index, labels in label_sources:
            for label in labels:
                try:
                    members = index.get(label.lower())
                except (AttributeError, TypeError) as exc:
                    logger.debug(
                        "Could not un-index label %r of node %s: %s",
                        label,
                        node_id,
                        exc,
                    )
                    continue
                if members is not None:
                    members.discard(node_id)

    def add_node(self, node: GraphNode) -> bool:
        """Add a node to the graph, replacing any existing node with the same id.

        When the id is already known, the index entries of the previous node
        object are removed and the graph attributes are replaced (not merged)
        with ``node.metadata``.

        Args:
            node: The node to register.

        Returns:
            True on success. On any error the failure is logged and False is
            returned.
        """
        try:
            # Derive everything that can fail on a malformed node *before*
            # touching any state, so a failed call does not leave the node
            # half-registered in ``nodes``/``graph`` but missing from indices.
            entity_keys = [entity.lower() for entity in node.entities]
            topic_keys = [topic.lower() for topic in node.topics]
            attributes = dict(node.metadata)
            type_label = node.node_type.value

            is_update = node.id in self.nodes
            if is_update:
                logger.debug(f"Node {node.id} already exists, updating")

            if node.id in self.graph:
                # Replace attributes with current metadata
                stored_attributes = self.graph.nodes[node.id]
                stored_attributes.clear()
                stored_attributes.update(attributes)
            else:
                self.graph.add_node(node.id, **attributes)

            if is_update:
                # Remove stale index entries for the existing node before overwrite
                self._unindex_node(self.nodes[node.id])

            self.nodes[node.id] = node

            # Add to indices for fast lookup
            self.node_type_index[node.node_type].add(node.id)
            for entity_key in entity_keys:
                self.entity_index[entity_key].add(node.id)
            for topic_key in topic_keys:
                self.topic_index[topic_key].add(node.id)

            logger.debug(f"Added {type_label} node: {node.id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add node {node.id}: {e}")
            return False

    def add_edge(self, edge: GraphEdge) -> bool:
        """Add an edge to the graph.

        Both endpoints must already be registered in ``nodes``. The edge is
        stored in the graph under the relationship value as its multi-edge
        key, with ``weight``, ``relationship`` and ``confidence`` attributes
        plus everything in ``edge.metadata``.

        Args:
            edge: The edge to register.

        Returns:
            True on success; False when an endpoint is missing or an error
            occurred (the error is logged).
        """
        try:
            source_id = edge.source_id
            target_id = edge.target_id
            if source_id not in self.nodes or target_id not in self.nodes:
                logger.warning(f"Edge {source_id} -> {target_id}: missing nodes")
                return False

            relationship = edge.relationship_type.value

            # Update the graph first: this call raises TypeError when a
            # metadata key collides with one of the explicit keyword
            # arguments, and in that case ``edges`` must not be left holding
            # an edge the graph does not contain.
            self.graph.add_edge(
                source_id,
                target_id,
                key=relationship,
                weight=edge.weight,
                relationship=relationship,
                confidence=edge.confidence,
                **edge.metadata,
            )
            self.edges[(source_id, target_id, relationship)] = edge

            logger.debug(f"Added edge: {source_id} --{relationship}--> {target_id}")
            return True

        except Exception as e:
            logger.error(
                f"Failed to add edge {edge.source_id} -> {edge.target_id}: {e}"
            )
            return False

    def find_nodes_by_type(self, node_type: NodeType) -> list[GraphNode]:
        """Find all nodes of a specific type.

        Uses ``.get`` so that querying an unknown type does not insert an
        empty entry into the ``defaultdict`` index.
        """
        node_ids = self.node_type_index.get(node_type, ())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_entity(self, entity: str) -> list[GraphNode]:
        """Find all nodes containing a specific entity (case-insensitive)."""
        node_ids = self.entity_index.get(entity.lower(), ())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_topic(self, topic: str) -> list[GraphNode]:
        """Find all nodes discussing a specific topic (case-insensitive)."""
        node_ids = self.topic_index.get(topic.lower(), ())
        return [self.nodes[node_id] for node_id in node_ids]

    def calculate_centrality_scores(self):
        """Calculate centrality scores for all nodes.

        Sets on every node object:

        * ``centrality_score``: 0.4 * degree centrality + 0.6 * betweenness
          centrality.
        * ``hub_score`` / ``authority_score``: HITS scores, or 0.0 when HITS
          could not be computed.

        Does nothing for an empty graph. Errors are logged, not raised.
        """
        try:
            if len(self.graph.nodes) == 0:
                return

            # Calculate different centrality metrics
            degree_centrality = nx.degree_centrality(self.graph)
            betweenness_centrality = nx.betweenness_centrality(self.graph)

            # For directed graphs, calculate hub and authority scores
            # Create a simple DiGraph view to ensure compatibility with HITS
            try:
                simple_digraph = nx.DiGraph(self.graph)
                hub_scores, authority_scores = nx.hits(simple_digraph, max_iter=100)
            except nx.PowerIterationFailedConvergence:
                logger.warning(
                    "HITS algorithm failed to converge, using default scores"
                )
                hub_scores = dict.fromkeys(self.graph.nodes, 0.0)
                authority_scores = dict.fromkeys(self.graph.nodes, 0.0)
            except Exception as hits_error:
                # HITS can also fail for reasons other than non-convergence on
                # degenerate graphs (e.g. very small or edgeless ones,
                # depending on the NetworkX/SciPy version). Hub/authority are
                # supplementary, so fall back to defaults instead of throwing
                # away the degree/betweenness results computed above.
                logger.warning(
                    "HITS algorithm failed (%s), using default scores", hits_error
                )
                hub_scores = dict.fromkeys(self.graph.nodes, 0.0)
                authority_scores = dict.fromkeys(self.graph.nodes, 0.0)

            # Update node objects with calculated scores
            for node_id, node in self.nodes.items():
                node.centrality_score = (
                    degree_centrality.get(node_id, 0.0) * 0.4
                    + betweenness_centrality.get(node_id, 0.0) * 0.6
                )
                node.hub_score = hub_scores.get(node_id, 0.0)
                node.authority_score = authority_scores.get(node_id, 0.0)

            logger.info(f"Calculated centrality scores for {len(self.nodes)} nodes")

        except Exception as e:
            logger.error(f"Failed to calculate centrality scores: {e}")

    def get_neighbors(
        self, node_id: str, relationship_types: list[RelationshipType] | None = None
    ) -> list[tuple[str, GraphEdge]]:
        """Get neighboring nodes with their connecting edges.

        Args:
            node_id: Id of the node whose outgoing neighbors are wanted.
            relationship_types: When given, only edges whose relationship is
                in this list are returned; ``None`` means no filtering.

        Returns:
            ``(neighbor_id, edge)`` pairs, one per matching edge that is also
            present in ``edges``. Empty when *node_id* is not in the graph.
        """
        neighbors: list[tuple[str, GraphEdge]] = []

        if node_id not in self.graph:
            return neighbors

        for neighbor_id in self.graph.neighbors(node_id):
            # Get all edges between these nodes
            edge_data = self.graph.get_edge_data(node_id, neighbor_id)
            if not edge_data:
                continue

            for data in edge_data.values():
                rel_value = data.get("relationship")
                if rel_value is None:
                    logger.debug(
                        "Skipping edge with missing relationship metadata: %s -> %s",
                        node_id,
                        neighbor_id,
                    )
                    continue

                try:
                    relationship = RelationshipType(rel_value)
                except ValueError:
                    logger.warning(
                        "Skipping edge with invalid relationship value '%s': %s -> %s",
                        rel_value,
                        node_id,
                        neighbor_id,
                    )
                    continue

                if (
                    relationship_types is not None
                    and relationship not in relationship_types
                ):
                    continue

                edge_key = (node_id, neighbor_id, relationship.value)
                if edge_key in self.edges:
                    neighbors.append((neighbor_id, self.edges[edge_key]))

        return neighbors

    def get_statistics(self) -> dict[str, Any]:
        """Get graph statistics.

        Returns:
            A dict with ``total_nodes``, ``total_edges``, ``node_types``
            (type value -> node count), ``relationship_types`` (relationship
            value -> edge count), ``connected_components`` (weakly connected)
            and ``avg_degree`` (0 for an empty graph).
        """
        # Count relationship types
        relationship_counts: dict[str, int] = {}
        for edge in self.edges.values():
            rel_type = edge.relationship_type.value
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1

        total_degree = sum(dict(self.graph.degree()).values())
        # max(..., 1) guards the division for an empty graph
        node_count = max(len(self.graph.nodes), 1)

        stats: dict[str, Any] = {
            "total_nodes": len(self.nodes),
            "total_edges": len(self.edges),
            "node_types": {
                node_type.value: len(nodes)
                for node_type, nodes in self.node_type_index.items()
            },
            "relationship_types": relationship_counts,
            "connected_components": nx.number_weakly_connected_components(self.graph),
            "avg_degree": total_degree / node_count,
        }

        return stats