Coverage for src/qdrant_loader_mcp_server/search/enhanced/kg/graph.py: 68%

122 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

1""" 

2Core Knowledge Graph Implementation. 

3 

4This module implements the core knowledge graph using NetworkX with optimized 

5node/edge management, indexing, and centrality calculations. 

6""" 

7 

8from __future__ import annotations 

9 

10from collections import defaultdict 

11from typing import Any 

12 

13import networkx as nx 

14 

15from ....utils.logging import LoggingConfig 

16from .models import GraphEdge, GraphNode, NodeType, RelationshipType 

17 

# Module-level logger obtained from the project's LoggingConfig helper, named after this module.
logger = LoggingConfig.get_logger(__name__)

19 

20 

class KnowledgeGraph:
    """Core knowledge graph implementation using NetworkX.

    Nodes and edges are kept in two places:

    * rich ``GraphNode`` / ``GraphEdge`` objects in ``self.nodes`` /
      ``self.edges``;
    * plain attribute dicts in ``self.graph`` (an ``nx.MultiDiGraph``), which
      the NetworkX algorithms run on.

    Secondary indices map a node type, a lower-cased entity, or a lower-cased
    topic to the set of node ids carrying it, for fast lookup.
    """

    def __init__(self):
        """Initialize an empty knowledge graph."""
        # Multigraph so several relationship types can link the same node
        # pair; the relationship value is used as the NetworkX edge key.
        self.graph = nx.MultiDiGraph()
        self.nodes: dict[str, GraphNode] = {}
        # Keyed by (source_id, target_id, relationship_type.value).
        self.edges: dict[tuple[str, str, str], GraphEdge] = {}
        self.node_type_index: dict[NodeType, set[str]] = defaultdict(set)
        self.entity_index: dict[str, set[str]] = defaultdict(set)  # entity -> node_ids
        self.topic_index: dict[str, set[str]] = defaultdict(set)  # topic -> node_ids

        logger.info("Initialized empty knowledge graph")

    # ------------------------------------------------------------------
    # Index maintenance helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _discard_from_index(
        index: dict[Any, set[str]], key: Any, node_id: str
    ) -> None:
        """Remove ``node_id`` from ``index[key]``; drop the bucket once empty.

        ``dict.get`` is used so a missing key is never materialized by the
        ``defaultdict`` factory.
        """
        bucket = index.get(key)
        if bucket is None:
            return
        bucket.discard(node_id)
        if not bucket:
            del index[key]

    def _unindex_node(self, node_id: str, node: GraphNode) -> None:
        """Remove every secondary-index entry recorded for ``node``."""
        self._discard_from_index(
            self.node_type_index, getattr(node, "node_type", None), node_id
        )
        for entity in getattr(node, "entities", None) or ():
            if isinstance(entity, str):
                self._discard_from_index(self.entity_index, entity.lower(), node_id)
        for topic in getattr(node, "topics", None) or ():
            if isinstance(topic, str):
                self._discard_from_index(self.topic_index, topic.lower(), node_id)

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------

    def add_node(self, node: GraphNode) -> bool:
        """Add a node to the graph, or replace an existing node with the same id.

        Args:
            node: The node to store. Its ``metadata`` becomes the NetworkX
                node attributes (replacing any previous attributes). Its
                ``entities`` and ``topics`` are indexed case-insensitively.

        Returns:
            True on success. False if the node could not be added; the error
            is logged, not raised.
        """
        try:
            # Normalize index keys before mutating anything, so a bad
            # entity/topic value cannot leave a half-indexed node behind.
            entity_keys = [entity.lower() for entity in node.entities]
            topic_keys = [topic.lower() for topic in node.topics]

            previous = self.nodes.get(node.id)
            if previous is not None:
                logger.debug(f"Node {node.id} already exists, updating")
                # Drop the replaced node's index entries (its type, entities
                # or topics may differ from the new node's).
                self._unindex_node(node.id, previous)

            self.nodes[node.id] = node
            # add_node is a no-op for an existing id. The attributes are then
            # replaced (not merged) so stale metadata keys do not linger.
            self.graph.add_node(node.id)
            attributes = self.graph.nodes[node.id]
            attributes.clear()
            attributes.update(node.metadata or {})

            self.node_type_index[node.node_type].add(node.id)
            for key in entity_keys:
                self.entity_index[key].add(node.id)
            for key in topic_keys:
                self.topic_index[key].add(node.id)

            logger.debug(f"Added {node.node_type.value} node: {node.id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add node {node.id}: {e}")
            return False

    def add_edge(self, edge: GraphEdge) -> bool:
        """Add (or replace) a directed edge between two existing nodes.

        At most one edge is kept per ``(source, target, relationship)``
        triple. Adding the same triple again replaces both the stored
        ``GraphEdge`` and its NetworkX attributes.

        Returns:
            True on success. False if either endpoint is unknown or the edge
            could not be added; the problem is logged, not raised.
        """
        try:
            if edge.source_id not in self.nodes or edge.target_id not in self.nodes:
                logger.warning(
                    f"Edge {edge.source_id} -> {edge.target_id}: missing nodes"
                )
                return False

            relationship = edge.relationship_type.value
            # Apply the core fields after the metadata. A metadata key such
            # as "relationship" therefore cannot collide with them as a
            # duplicate keyword, nor shadow the relationship value that
            # get_neighbors reads back.
            attributes = dict(edge.metadata or {})
            attributes.update(
                weight=edge.weight,
                relationship=relationship,
                confidence=edge.confidence,
            )

            # Update the NetworkX graph first. self.edges is only written once
            # the graph side has succeeded, so the two stores stay in sync.
            self.graph.add_edge(edge.source_id, edge.target_id, key=relationship)
            edge_attributes = self.graph.edges[
                edge.source_id, edge.target_id, relationship
            ]
            edge_attributes.clear()
            edge_attributes.update(attributes)

            self.edges[(edge.source_id, edge.target_id, relationship)] = edge

            logger.debug(
                f"Added edge: {edge.source_id} --{edge.relationship_type.value}--> {edge.target_id}"
            )
            return True

        except Exception as e:
            logger.error(
                f"Failed to add edge {edge.source_id} -> {edge.target_id}: {e}"
            )
            return False

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def find_nodes_by_type(self, node_type: NodeType) -> list[GraphNode]:
        """Return all nodes of ``node_type``, or an empty list if there are none.

        ``dict.get`` avoids inserting an empty bucket into the defaultdict on
        a read. Such a bucket would otherwise show up as a zero count in
        ``get_statistics``.
        """
        node_ids = self.node_type_index.get(node_type, ())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_entity(self, entity: str) -> list[GraphNode]:
        """Return all nodes mentioning ``entity`` (case-insensitive)."""
        node_ids = self.entity_index.get(entity.lower(), set())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_topic(self, topic: str) -> list[GraphNode]:
        """Return all nodes discussing ``topic`` (case-insensitive)."""
        node_ids = self.topic_index.get(topic.lower(), set())
        return [self.nodes[node_id] for node_id in node_ids]

    # ------------------------------------------------------------------
    # Analytics
    # ------------------------------------------------------------------

    def _compute_hits(self) -> tuple[dict[str, float], dict[str, float]]:
        """Return ``(hub_scores, authority_scores)``, falling back to all zeros.

        HITS is an optional refinement. If it cannot be computed, all scores
        default to 0.0, so the other centrality metrics are still applied.
        """
        zeros = dict.fromkeys(self.graph.nodes, 0.0)
        if self.graph.number_of_edges() == 0:
            # Without edges there is no hub/authority structure. Skip the
            # computation rather than rely on the backend's handling of an
            # all-zero adjacency matrix.
            return zeros, dict(zeros)
        try:
            # Collapse parallel edges: HITS is run on a simple DiGraph.
            simple_digraph = nx.DiGraph(self.graph)
            return nx.hits(simple_digraph, max_iter=100)
        except nx.PowerIterationFailedConvergence:
            logger.warning("HITS algorithm failed to converge, using default scores")
        except Exception as e:
            logger.warning(f"HITS computation failed ({e}), using default scores")
        return zeros, dict(zeros)

    def calculate_centrality_scores(self):
        """Compute and store centrality metrics on every node.

        Sets these attributes on each ``GraphNode``:

        * ``centrality_score``: 0.4 × degree centrality + 0.6 × betweenness
          centrality;
        * ``hub_score`` and ``authority_score``: HITS results, or 0.0 when
          HITS is unavailable.

        Errors are logged, not raised.
        """
        try:
            if self.graph.number_of_nodes() == 0:
                return

            degree_centrality = nx.degree_centrality(self.graph)
            betweenness_centrality = nx.betweenness_centrality(self.graph)
            hub_scores, authority_scores = self._compute_hits()

            for node_id, node in self.nodes.items():
                node.centrality_score = (
                    degree_centrality.get(node_id, 0.0) * 0.4
                    + betweenness_centrality.get(node_id, 0.0) * 0.6
                )
                node.hub_score = hub_scores.get(node_id, 0.0)
                node.authority_score = authority_scores.get(node_id, 0.0)

            logger.info(f"Calculated centrality scores for {len(self.nodes)} nodes")

        except Exception as e:
            logger.error(f"Failed to calculate centrality scores: {e}")

    def get_neighbors(
        self, node_id: str, relationship_types: list[RelationshipType] | None = None
    ) -> list[tuple[str, GraphEdge]]:
        """Return the outgoing neighbors of ``node_id`` with their connecting edges.

        Args:
            node_id: Source node id. Unknown ids yield an empty list.
            relationship_types: If given, only edges whose relationship is in
                this collection are returned.

        Returns:
            ``(neighbor_id, edge)`` pairs, one per matching stored edge. A
            neighbor linked by several relationships appears once per
            relationship.
        """
        neighbors: list[tuple[str, GraphEdge]] = []
        if node_id not in self.graph:
            return neighbors

        # In a MultiDiGraph, adj[node_id] maps each successor to a
        # {edge_key: attribute_dict} mapping of the parallel edges.
        for neighbor_id, keyed_edges in self.graph.adj[node_id].items():
            for data in keyed_edges.values():
                rel_value = data.get("relationship")
                if rel_value is None:
                    logger.debug(
                        "Skipping edge with missing relationship metadata: %s -> %s",
                        node_id,
                        neighbor_id,
                    )
                    continue
                try:
                    relationship = RelationshipType(rel_value)
                except ValueError:
                    logger.warning(
                        "Skipping edge with invalid relationship value '%s': %s -> %s",
                        rel_value,
                        node_id,
                        neighbor_id,
                    )
                    continue
                if relationship_types is not None and relationship not in relationship_types:
                    continue
                stored_edge = self.edges.get((node_id, neighbor_id, relationship.value))
                if stored_edge is not None:
                    neighbors.append((neighbor_id, stored_edge))

        return neighbors

    def get_statistics(self) -> dict[str, Any]:
        """Return summary statistics about the graph.

        Keys: ``total_nodes``, ``total_edges``, ``node_types`` (node-type
        value -> count, omitting empty types), ``relationship_types``
        (relationship value -> count), ``connected_components`` (weakly
        connected) and ``avg_degree``.
        """
        relationship_counts: dict[str, int] = {}
        for edge in self.edges.values():
            rel_type = edge.relationship_type.value
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1

        node_count = self.graph.number_of_nodes()
        return {
            "total_nodes": len(self.nodes),
            "total_edges": len(self.edges),
            "node_types": {
                node_type.value: len(node_ids)
                for node_type, node_ids in self.node_type_index.items()
                if node_ids
            },
            "relationship_types": relationship_counts,
            "connected_components": nx.number_weakly_connected_components(self.graph),
            # max(..., 1) guards the empty graph against division by zero.
            "avg_degree": sum(degree for _, degree in self.graph.degree())
            / max(node_count, 1),
        }