Coverage for src/qdrant_loader_mcp_server/search/enhanced/kg/graph.py: 68%
122 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-08 06:06 +0000
1"""
2Core Knowledge Graph Implementation.
4This module implements the core knowledge graph using NetworkX with optimized
5node/edge management, indexing, and centrality calculations.
6"""
8from __future__ import annotations
10from collections import defaultdict
11from typing import Any
13import networkx as nx
15from ....utils.logging import LoggingConfig
16from .models import GraphEdge, GraphNode, NodeType, RelationshipType
18logger = LoggingConfig.get_logger(__name__)
class KnowledgeGraph:
    """Core knowledge graph implementation using NetworkX.

    Topology is stored in ``self.graph`` (a ``networkx.MultiDiGraph``), while the
    ``GraphNode`` / ``GraphEdge`` objects themselves are kept in ``self.nodes`` and
    ``self.edges``. Three inverted indices (by node type, by lower-cased entity
    and by lower-cased topic) map to sets of node ids for fast lookup.
    """

    def __init__(self):
        """Initialize the knowledge graph."""
        self.graph = nx.MultiDiGraph()  # Allow multiple edges between nodes
        self.nodes: dict[str, GraphNode] = {}
        # Keyed by (source, target, relationship)
        self.edges: dict[tuple[str, str, str], GraphEdge] = {}
        self.node_type_index: dict[NodeType, set[str]] = defaultdict(set)
        self.entity_index: dict[str, set[str]] = defaultdict(set)  # entity -> node_ids
        self.topic_index: dict[str, set[str]] = defaultdict(set)  # topic -> node_ids

        logger.info("Initialized empty knowledge graph")

    def _unindex_node(self, node: GraphNode) -> None:
        """Remove ``node``'s id from the type, entity and topic indices.

        This is best effort: a malformed value on the stale node is logged and
        skipped instead of aborting the update. Lookups use ``.get`` so that
        un-indexing never creates new (empty) index entries.
        """
        node_id = node.id

        try:
            type_members = self.node_type_index.get(node.node_type)
        except (AttributeError, TypeError) as exc:
            logger.debug(f"Could not un-index type of node {node_id}: {exc}")
        else:
            if type_members is not None:
                type_members.discard(node_id)

        for index, attr_name in (
            (self.entity_index, "entities"),
            (self.topic_index, "topics"),
        ):
            for raw_value in getattr(node, attr_name, None) or ():
                try:
                    members = index.get(raw_value.lower())
                except (AttributeError, TypeError) as exc:
                    logger.debug(
                        f"Skipping malformed {attr_name} value {raw_value!r} "
                        f"on node {node_id}: {exc}"
                    )
                    continue
                if members is not None:
                    members.discard(node_id)

    def add_node(self, node: GraphNode) -> bool:
        """Add a node to the graph, or replace an existing node with the same id.

        When the id already exists, the stale index entries of the previous node
        are removed and the graph attributes are replaced by ``node.metadata``.

        Args:
            node: The node to insert or update.

        Returns:
            True on success, False if anything failed (the error is logged).
        """
        try:
            # Evaluate everything that can raise *before* mutating any state, so
            # a malformed node cannot leave the graph and indices half-updated.
            type_label = node.node_type.value
            entity_keys = {entity.lower() for entity in node.entities}
            topic_keys = {topic.lower() for topic in node.topics}
            metadata = dict(node.metadata)

            previous = self.nodes.get(node.id)
            if previous is not None:
                logger.debug(f"Node {node.id} already exists, updating")
                self._unindex_node(previous)

            self.nodes[node.id] = node

            # add_node() leaves an already-present node untouched when called
            # without attributes; the attributes are then replaced through the
            # attribute dict, which avoids keyword collisions with add_node()'s
            # own parameter names.
            self.graph.add_node(node.id)
            graph_attrs = self.graph.nodes[node.id]
            graph_attrs.clear()
            graph_attrs.update(metadata)

            # Add to indices for fast lookup
            self.node_type_index[node.node_type].add(node.id)
            for entity_key in entity_keys:
                self.entity_index[entity_key].add(node.id)
            for topic_key in topic_keys:
                self.topic_index[topic_key].add(node.id)

            logger.debug(f"Added {type_label} node: {node.id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add node {node.id}: {e}")
            return False

    def add_edge(self, edge: GraphEdge) -> bool:
        """Add an edge to the graph.

        Both endpoints must already be registered via ``add_node``. The edge is
        keyed by its relationship value, so a second edge with the same source,
        target and relationship overwrites the stored ``GraphEdge``.

        Args:
            edge: The edge to insert.

        Returns:
            True on success, False if an endpoint is missing or insertion failed.
        """
        try:
            if edge.source_id not in self.nodes or edge.target_id not in self.nodes:
                logger.warning(
                    f"Edge {edge.source_id} -> {edge.target_id}: missing nodes"
                )
                return False

            relationship = edge.relationship_type.value

            # Reserved attributes are applied last so they win over same-named
            # metadata keys (get_neighbors relies on "relationship"), instead of
            # raising a duplicate-keyword TypeError.
            attrs = dict(edge.metadata)
            attrs.update(
                weight=edge.weight,
                relationship=relationship,
                confidence=edge.confidence,
            )

            self.graph.add_edge(
                edge.source_id,
                edge.target_id,
                key=relationship,
                **attrs,
            )
            # Register the edge object only once the graph accepted it, keeping
            # self.edges and self.graph consistent if add_edge() raises.
            self.edges[(edge.source_id, edge.target_id, relationship)] = edge

            logger.debug(
                f"Added edge: {edge.source_id} --{edge.relationship_type.value}--> {edge.target_id}"
            )
            return True

        except Exception as e:
            logger.error(
                f"Failed to add edge {edge.source_id} -> {edge.target_id}: {e}"
            )
            return False

    def find_nodes_by_type(self, node_type: NodeType) -> list[GraphNode]:
        """Find all nodes of a specific type.

        Uses ``.get`` so that querying an unknown type does not insert an empty
        entry into the ``defaultdict`` index (which would otherwise appear as a
        zero-count type in ``get_statistics``).
        """
        node_ids = self.node_type_index.get(node_type, ())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_entity(self, entity: str) -> list[GraphNode]:
        """Find all nodes containing a specific entity (case-insensitive)."""
        node_ids = self.entity_index.get(entity.lower(), ())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_topic(self, topic: str) -> list[GraphNode]:
        """Find all nodes discussing a specific topic (case-insensitive)."""
        node_ids = self.topic_index.get(topic.lower(), ())
        return [self.nodes[node_id] for node_id in node_ids]

    def _compute_hits_scores(self) -> tuple[dict[str, float], dict[str, float]]:
        """Return ``(hub_scores, authority_scores)``, falling back to zeros.

        HITS is run on a plain ``DiGraph`` copy of the multigraph. Any failure
        yields all-zero scores so that the caller can still apply the degree and
        betweenness results. Non-convergence is the expected failure; other
        errors are believed possible too (e.g. from the numerical backend on
        very small graphs - not verified here).
        """
        try:
            # Create a simple DiGraph view to ensure compatibility with HITS
            simple_digraph = nx.DiGraph(self.graph)
            hub_scores, authority_scores = nx.hits(simple_digraph, max_iter=100)
            return hub_scores, authority_scores
        except nx.PowerIterationFailedConvergence:
            logger.warning("HITS algorithm failed to converge, using default scores")
        except Exception as e:
            logger.warning(f"HITS algorithm failed ({e}), using default scores")

        return (
            dict.fromkeys(self.graph.nodes, 0.0),
            dict.fromkeys(self.graph.nodes, 0.0),
        )

    def calculate_centrality_scores(self) -> None:
        """Calculate centrality scores for all nodes.

        Sets ``centrality_score`` (0.4 * degree centrality + 0.6 * betweenness
        centrality), ``hub_score`` and ``authority_score`` on every node object
        in ``self.nodes``. Does nothing for an empty graph; errors are logged,
        not raised.
        """
        try:
            if self.graph.number_of_nodes() == 0:
                return

            # Calculate different centrality metrics
            degree_centrality = nx.degree_centrality(self.graph)
            betweenness_centrality = nx.betweenness_centrality(self.graph)

            # For directed graphs, calculate hub and authority scores
            hub_scores, authority_scores = self._compute_hits_scores()

            # Update node objects with calculated scores
            for node_id, node in self.nodes.items():
                node.centrality_score = (
                    degree_centrality.get(node_id, 0.0) * 0.4
                    + betweenness_centrality.get(node_id, 0.0) * 0.6
                )
                node.hub_score = hub_scores.get(node_id, 0.0)
                node.authority_score = authority_scores.get(node_id, 0.0)

            logger.info(f"Calculated centrality scores for {len(self.nodes)} nodes")

        except Exception as e:
            logger.error(f"Failed to calculate centrality scores: {e}")

    def get_neighbors(
        self, node_id: str, relationship_types: list[RelationshipType] | None = None
    ) -> list[tuple[str, GraphEdge]]:
        """Get neighboring nodes with their connecting edges.

        Args:
            node_id: Id of the node whose outgoing neighbors are requested.
            relationship_types: If given, only edges whose relationship is in
                this list are returned; ``None`` means no filtering.

        Returns:
            A list of ``(neighbor_id, edge)`` pairs. Empty if the node is not in
            the graph. Edges with missing or invalid relationship data, or with
            no matching entry in ``self.edges``, are skipped.
        """
        neighbors: list[tuple[str, GraphEdge]] = []

        if node_id not in self.graph:
            return neighbors

        for neighbor_id in self.graph.neighbors(node_id):
            # Get all edges between these nodes
            edge_data = self.graph.get_edge_data(node_id, neighbor_id)
            if not edge_data:
                continue

            for data in edge_data.values():
                rel_value = data.get("relationship")
                if rel_value is None:
                    logger.debug(
                        "Skipping edge with missing relationship metadata: %s -> %s",
                        node_id,
                        neighbor_id,
                    )
                    continue

                try:
                    relationship = RelationshipType(rel_value)
                except ValueError:
                    logger.warning(
                        "Skipping edge with invalid relationship value '%s': %s -> %s",
                        rel_value,
                        node_id,
                        neighbor_id,
                    )
                    continue

                if (
                    relationship_types is not None
                    and relationship not in relationship_types
                ):
                    continue

                edge = self.edges.get((node_id, neighbor_id, relationship.value))
                if edge is not None:
                    neighbors.append((neighbor_id, edge))

        return neighbors

    def get_statistics(self) -> dict[str, Any]:
        """Get graph statistics.

        Returns:
            A dict with ``total_nodes``, ``total_edges``, per-type node counts
            (``node_types``), per-relationship edge counts
            (``relationship_types``), the number of weakly connected components
            (``connected_components``) and the average degree (``avg_degree``).
        """
        # Count relationship types
        relationship_counts: dict[str, int] = {}
        for edge in self.edges.values():
            rel_type = edge.relationship_type.value
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1

        total_degree = sum(degree for _, degree in self.graph.degree())
        # max(..., 1) guards against division by zero on an empty graph
        node_count = max(self.graph.number_of_nodes(), 1)

        return {
            "total_nodes": len(self.nodes),
            "total_edges": len(self.edges),
            "node_types": {
                node_type.value: len(node_ids)
                for node_type, node_ids in self.node_type_index.items()
            },
            "relationship_types": relationship_counts,
            "connected_components": nx.number_weakly_connected_components(self.graph),
            "avg_degree": total_degree / node_count,
        }