Coverage for src / qdrant_loader_mcp_server / search / enhanced / kg / graph.py: 68%
122 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 04:51 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 04:51 +0000
1"""
2Core Knowledge Graph Implementation.
4This module implements the core knowledge graph using NetworkX with optimized
5node/edge management, indexing, and centrality calculations.
6"""
8from __future__ import annotations
10from collections import defaultdict
11from typing import Any
13# Note: networkx is imported at module level because:
14# 1. This module (kg/graph) is only imported when KnowledgeGraph is needed
15# 2. Tests need to be able to patch nx functions
16# The lazy loading happens at a higher level
17import networkx as nx
19from ....utils.logging import LoggingConfig
20from .models import GraphEdge, GraphNode, NodeType, RelationshipType
# Module-level logger, obtained from the project's LoggingConfig and used by
# KnowledgeGraph for info/debug/warning/error messages.
logger = LoggingConfig.get_logger(__name__)
class KnowledgeGraph:
    """Core knowledge graph implementation using NetworkX.

    State is kept in three places that the methods below keep in sync:

    * ``self.nodes`` / ``self.edges`` hold the ``GraphNode`` / ``GraphEdge``
      model objects.
    * ``self.graph`` is a ``networkx.MultiDiGraph`` mirroring them; traversal
      and centrality algorithms run on it.
    * ``node_type_index``, ``entity_index`` and ``topic_index`` map a node
      type / lower-cased entity / lower-cased topic to the ids of the nodes
      that carry it.
    """

    def __init__(self):
        """Initialize an empty knowledge graph and its lookup indices."""
        # MultiDiGraph allows multiple edges between the same pair of nodes;
        # add_edge uses the relationship type value as the edge key.
        self.graph = nx.MultiDiGraph()
        self.nodes: dict[str, GraphNode] = {}
        # (source_id, target_id, relationship_type.value) -> edge
        self.edges: dict[tuple[str, str, str], GraphEdge] = {}
        self.node_type_index: dict[NodeType, set[str]] = defaultdict(set)
        self.entity_index: dict[str, set[str]] = defaultdict(set)  # entity -> node_ids
        self.topic_index: dict[str, set[str]] = defaultdict(set)  # topic -> node_ids

        logger.info("Initialized empty knowledge graph")

    @staticmethod
    def _discard_from_index(
        index: dict[Any, set[str]], key: Any, node_id: str
    ) -> None:
        """Remove ``node_id`` from ``index[key]``; drop the bucket once it is empty."""
        # .get() rather than [] so that a removal never grows a defaultdict index.
        bucket = index.get(key)
        if bucket is None:
            return
        bucket.discard(node_id)
        if not bucket:
            # Pruning keeps empty buckets from showing up in get_statistics().
            del index[key]

    def _unindex_node(self, old: GraphNode, node_id: str) -> None:
        """Best-effort removal of ``old``'s entries from the secondary indices."""
        # Each removal is guarded separately: ``old`` may hold values that were
        # never indexable (e.g. a non-string entity), and one bad value must not
        # prevent the remaining stale entries from being cleaned up.
        try:
            self._discard_from_index(self.node_type_index, old.node_type, node_id)
        except (AttributeError, TypeError):
            pass
        for index, values in (
            (self.entity_index, getattr(old, "entities", None)),
            (self.topic_index, getattr(old, "topics", None)),
        ):
            for value in values or ():
                try:
                    self._discard_from_index(index, value.lower(), node_id)
                except (AttributeError, TypeError):
                    pass

    def add_node(self, node: GraphNode) -> bool:
        """Add ``node`` to the graph, replacing any existing node with the same id.

        When the id is already present, the previous node's index entries are
        removed first and the graph node's attributes are replaced (not merged)
        with ``node.metadata``.

        Returns:
            True on success; False if any step raised (the error is logged).
        """
        try:
            if node.id in self.nodes:
                logger.debug(f"Node {node.id} already exists, updating")
                # Remove stale index entries for the existing node before overwrite
                self._unindex_node(self.nodes[node.id], node.id)

            # Overwrite node object and update graph node attributes
            self.nodes[node.id] = node
            if node.id in self.graph:
                # Replace attributes so keys missing from the new metadata do
                # not linger on the graph node.
                attrs = self.graph.nodes[node.id]
                attrs.clear()
                attrs.update(node.metadata)
            else:
                self.graph.add_node(node.id, **node.metadata)

            # Add to indices for fast lookup (entities/topics are lower-cased)
            self.node_type_index[node.node_type].add(node.id)
            for entity in node.entities:
                self.entity_index[entity.lower()].add(node.id)
            for topic in node.topics:
                self.topic_index[topic.lower()].add(node.id)

            logger.debug(f"Added {node.node_type.value} node: {node.id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add node {node.id}: {e}")
            return False

    def add_edge(self, edge: GraphEdge) -> bool:
        """Add ``edge`` to the graph.

        Both endpoints must already have been added via ``add_node``. An edge
        with the same (source, target, relationship type) replaces the stored
        ``GraphEdge`` and updates the graph edge's attributes.

        Returns:
            True on success; False if an endpoint is missing or a step raised
            (logged).
        """
        try:
            if edge.source_id not in self.nodes or edge.target_id not in self.nodes:
                logger.warning(
                    f"Edge {edge.source_id} -> {edge.target_id}: missing nodes"
                )
                return False

            relationship = edge.relationship_type.value
            # Metadata first so the canonical fields win on collision; passing
            # them both explicitly and via **metadata used to raise TypeError.
            attrs = {
                **(edge.metadata or {}),
                "weight": edge.weight,
                "relationship": relationship,
                "confidence": edge.confidence,
            }
            # "key" is add_edge's own keyword and cannot be an edge attribute.
            attrs.pop("key", None)

            self.graph.add_edge(
                edge.source_id, edge.target_id, key=relationship, **attrs
            )
            # Record the edge only after the graph accepted it, so that
            # self.edges and self.graph never disagree.
            self.edges[(edge.source_id, edge.target_id, relationship)] = edge

            logger.debug(
                f"Added edge: {edge.source_id} --{edge.relationship_type.value}--> {edge.target_id}"
            )
            return True

        except Exception as e:
            logger.error(
                f"Failed to add edge {edge.source_id} -> {edge.target_id}: {e}"
            )
            return False

    def find_nodes_by_type(self, node_type: NodeType) -> list[GraphNode]:
        """Return all nodes of ``node_type`` (empty list if there are none)."""
        # .get() avoids inserting an empty bucket into the defaultdict on a
        # read, matching find_nodes_by_entity / find_nodes_by_topic.
        node_ids = self.node_type_index.get(node_type, set())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_entity(self, entity: str) -> list[GraphNode]:
        """Return all nodes containing ``entity`` (matched case-insensitively)."""
        node_ids = self.entity_index.get(entity.lower(), set())
        return [self.nodes[node_id] for node_id in node_ids]

    def find_nodes_by_topic(self, topic: str) -> list[GraphNode]:
        """Return all nodes discussing ``topic`` (matched case-insensitively)."""
        node_ids = self.topic_index.get(topic.lower(), set())
        return [self.nodes[node_id] for node_id in node_ids]

    def calculate_centrality_scores(self):
        """Compute and store centrality, hub and authority scores on every node.

        ``centrality_score`` is 0.4 * degree centrality + 0.6 * betweenness
        centrality. ``hub_score`` / ``authority_score`` come from HITS run on
        a simple ``DiGraph`` copy. If HITS fails, they fall back to 0.0 while
        the degree/betweenness results are still applied. Errors are logged,
        not raised; an empty graph is a no-op.
        """
        try:
            if len(self.graph.nodes) == 0:
                return

            # Calculate different centrality metrics
            degree_centrality = nx.degree_centrality(self.graph)
            betweenness_centrality = nx.betweenness_centrality(self.graph)

            # Nodes missing from these dicts get 0.0 below, so an empty dict is
            # the "default scores" fallback.
            hub_scores: dict = {}
            authority_scores: dict = {}
            try:
                # HITS needs a simple DiGraph; parallel edges are collapsed.
                simple_digraph = nx.DiGraph(self.graph)
                hub_scores, authority_scores = nx.hits(simple_digraph, max_iter=100)
            except nx.PowerIterationFailedConvergence:
                logger.warning(
                    "HITS algorithm failed to converge, using default scores"
                )
            except Exception as e:
                # Other HITS failures (backend/version dependent) must not
                # discard the degree and betweenness scores computed above.
                logger.warning(f"HITS algorithm failed ({e}), using default scores")

            # Update node objects with calculated scores
            for node_id, node in self.nodes.items():
                node.centrality_score = (
                    degree_centrality.get(node_id, 0.0) * 0.4
                    + betweenness_centrality.get(node_id, 0.0) * 0.6
                )
                node.hub_score = hub_scores.get(node_id, 0.0)
                node.authority_score = authority_scores.get(node_id, 0.0)

            logger.info(f"Calculated centrality scores for {len(self.nodes)} nodes")

        except Exception as e:
            logger.error(f"Failed to calculate centrality scores: {e}")

    def get_neighbors(
        self, node_id: str, relationship_types: list[RelationshipType] | None = None
    ) -> list[tuple[str, GraphEdge]]:
        """Return ``(neighbor_id, edge)`` pairs for the outgoing edges of ``node_id``.

        Args:
            node_id: Node whose successors are returned.
            relationship_types: If given, only edges of these types are returned.

        Returns:
            One pair per matching edge, so a neighbor linked by several
            relationship types appears once per type. Graph edges with a missing
            or unknown ``relationship`` attribute, or with no matching entry in
            ``self.edges``, are skipped. Empty if ``node_id`` is not in the graph.
        """
        neighbors: list[tuple[str, GraphEdge]] = []

        if node_id not in self.graph:
            return neighbors

        # On a MultiDiGraph, graph[node_id] maps each successor to
        # {edge_key: edge_attributes}.
        for neighbor_id, keyed_edges in self.graph[node_id].items():
            for data in keyed_edges.values():
                rel_value = data.get("relationship")
                if rel_value is None:
                    logger.debug(
                        "Skipping edge with missing relationship metadata: %s -> %s",
                        node_id,
                        neighbor_id,
                    )
                    continue
                try:
                    relationship = RelationshipType(rel_value)
                except ValueError:
                    logger.warning(
                        "Skipping edge with invalid relationship value '%s': %s -> %s",
                        rel_value,
                        node_id,
                        neighbor_id,
                    )
                    continue
                if relationship_types is None or relationship in relationship_types:
                    edge_key = (node_id, neighbor_id, relationship.value)
                    if edge_key in self.edges:
                        neighbors.append((neighbor_id, self.edges[edge_key]))

        return neighbors

    def get_statistics(self) -> dict[str, Any]:
        """Return summary statistics about the graph.

        Keys: ``total_nodes``, ``total_edges``, ``node_types`` (type value ->
        count), ``relationship_types`` (relationship value -> count),
        ``connected_components`` (weakly connected) and ``avg_degree``.
        """
        # Count relationship types
        relationship_counts: dict[str, int] = {}
        for edge in self.edges.values():
            rel_type = edge.relationship_type.value
            relationship_counts[rel_type] = relationship_counts.get(rel_type, 0) + 1

        return {
            "total_nodes": len(self.nodes),
            "total_edges": len(self.edges),
            # Empty buckets (e.g. a type whose last node was re-typed) are skipped.
            "node_types": {
                node_type.value: len(node_ids)
                for node_type, node_ids in self.node_type_index.items()
                if node_ids
            },
            "relationship_types": relationship_counts,
            "connected_components": nx.number_weakly_connected_components(self.graph),
            "avg_degree": sum(degree for _, degree in self.graph.degree())
            / max(len(self.graph.nodes), 1),
        }