Coverage for src / qdrant_loader_mcp_server / search / enhanced / cdi / models.py: 99%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 04:51 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass, field 

4from datetime import datetime 

5from enum import Enum 

6from typing import Any 

7 

8# Note: networkx is imported at module level because: 

9# 1. This module (cdi/models) is only imported when CDI functionality is needed 

10# 2. Tests need to be able to patch nx.hits, nx.pagerank, etc. 

11# The lazy loading happens at a higher level - this module is not imported at MCP startup 

12import networkx as nx 

13from networkx.exception import NetworkXError, PowerIterationFailedConvergence 

14 

15from ....utils.logging import LoggingConfig 

16 

# Module-level logger; used to report centrality-computation failures
# in CitationNetwork.calculate_centrality_scores.
logger = LoggingConfig.get_logger(__name__)

18 

19 

class SimilarityMetric(Enum):
    """Types of similarity metrics for document comparison.

    Each member's string value is the stable identifier rendered in
    human-readable output (DocumentSimilarity.get_display_explanation
    interpolates ``metric.value`` into its explanation text).
    """

    ENTITY_OVERLAP = "entity_overlap"
    TOPIC_OVERLAP = "topic_overlap"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    METADATA_SIMILARITY = "metadata_similarity"
    HIERARCHICAL_DISTANCE = "hierarchical_distance"
    CONTENT_FEATURES = "content_features"
    COMBINED = "combined"

30 

31 

class ClusteringStrategy(Enum):
    """Strategies for document clustering.

    Each member's string value is the serialized strategy name exposed in
    cluster summaries (DocumentCluster.get_cluster_summary reports
    ``cluster_strategy.value`` under the "strategy" key).
    """

    ENTITY_BASED = "entity_based"
    TOPIC_BASED = "topic_based"
    PROJECT_BASED = "project_based"
    HIERARCHICAL = "hierarchical"
    MIXED_FEATURES = "mixed_features"
    SEMANTIC_EMBEDDING = "semantic_embedding"

41 

42 

class RelationshipType(Enum):
    """Types of relationships between documents.

    Used to label how a pair of documents relates (see
    DocumentSimilarity.relationship_type, which defaults to
    SEMANTIC_SIMILARITY).
    """

    HIERARCHICAL = "hierarchical"
    CROSS_REFERENCE = "cross_reference"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    COMPLEMENTARY = "complementary"
    CONFLICTING = "conflicting"
    SEQUENTIAL = "sequential"
    TOPICAL_GROUPING = "topical_grouping"
    PROJECT_GROUPING = "project_grouping"

54 

55 

@dataclass
class DocumentSimilarity:
    """Similarity relationship between a pair of documents.

    Holds the combined score plus per-metric breakdowns and the shared
    features (entities/topics) that justify the match.
    """

    doc1_id: str
    doc2_id: str
    similarity_score: float  # 0.0 - 1.0
    metric_scores: dict[SimilarityMetric, float] = field(default_factory=dict)
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    relationship_type: RelationshipType = RelationshipType.SEMANTIC_SIMILARITY
    explanation: str = ""

    def get_display_explanation(self) -> str:
        """Return a human-readable explanation of the similarity.

        An explicitly provided ``explanation`` wins; otherwise one is
        assembled from shared entities, shared topics, and the strongest
        individual metric. Falls back to a generic label when no signal
        is available.
        """
        if self.explanation:
            return self.explanation

        parts: list[str] = []
        if self.shared_entities:
            # Cap at three entities to keep the summary short.
            parts.append(f"Shared entities: {', '.join(self.shared_entities[:3])}")
        if self.shared_topics:
            parts.append(f"Shared topics: {', '.join(self.shared_topics[:3])}")
        if self.metric_scores:
            metric, score = max(self.metric_scores.items(), key=lambda kv: kv[1])
            parts.append(f"High {metric.value}: {score:.2f}")

        if not parts:
            return "Semantic similarity"
        return "; ".join(parts)

86 

87 

@dataclass
class DocumentCluster:
    """A group of related documents produced by a clustering strategy."""

    cluster_id: str
    name: str
    documents: list[str] = field(default_factory=list)  # Document IDs
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    cluster_strategy: ClusteringStrategy = ClusteringStrategy.MIXED_FEATURES
    coherence_score: float = 0.0  # 0.0 - 1.0
    representative_doc_id: str = ""
    cluster_description: str = ""

    def get_cluster_summary(self) -> dict[str, Any]:
        """Return a serializable summary of this cluster.

        Entities/topics are truncated to the top five each; the strategy
        is reported by its string value.
        """
        summary: dict[str, Any] = {
            "cluster_id": self.cluster_id,
            "name": self.name,
            "document_count": len(self.documents),
            "coherence_score": self.coherence_score,
            "primary_entities": self.shared_entities[:5],
            "primary_topics": self.shared_topics[:5],
            "strategy": self.cluster_strategy.value,
            "description": self.cluster_description,
        }
        return summary

114 

115 

@dataclass
class CitationNetwork:
    """Citation/reference network between documents.

    ``nodes``/``edges`` describe the network declaratively; ``build_graph``
    materializes them into a NetworkX digraph, and
    ``calculate_centrality_scores`` populates the score dictionaries.
    """

    nodes: dict[str, dict[str, Any]] = field(default_factory=dict)  # doc_id -> metadata
    edges: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (from, to, metadata)
    graph: nx.DiGraph | None = None
    authority_scores: dict[str, float] = field(default_factory=dict)
    hub_scores: dict[str, float] = field(default_factory=dict)
    pagerank_scores: dict[str, float] = field(default_factory=dict)

    def build_graph(self) -> nx.DiGraph:
        """Build (or refresh) the NetworkX graph from ``nodes`` and ``edges``.

        Re-applies all declared nodes/edges onto the existing graph when one
        is already cached; node/edge attributes come from the stored metadata.
        """
        if self.graph is None:
            self.graph = nx.DiGraph()

        graph = self.graph
        for doc_id, metadata in self.nodes.items():
            graph.add_node(doc_id, **metadata)
        for source_doc, target_doc, attrs in self.edges:
            graph.add_edge(source_doc, target_doc, **attrs)
        return graph

    def calculate_centrality_scores(self):
        """Calculate authority/hub (HITS) and PageRank scores for the network.

        With no edges — or when HITS/PageRank fail to converge or raise —
        degree centrality is used for all three score dictionaries so
        callers always receive populated scores.
        """
        if self.graph is None:
            self.build_graph()

        try:
            if self.graph.number_of_edges() == 0:
                # HITS/PageRank are undefined without edges; keep the
                # score dicts populated via degree centrality instead.
                if self.graph.nodes():
                    fallback = nx.degree_centrality(self.graph)
                    self.authority_scores = fallback
                    self.hub_scores = fallback
                    self.pagerank_scores = fallback
                return

            # nx.hits returns (hubs, authorities).
            hubs, authorities = nx.hits(self.graph, max_iter=100, normalized=True)
            self.hub_scores = hubs
            self.authority_scores = authorities

            self.pagerank_scores = nx.pagerank(self.graph, max_iter=100)

        except (NetworkXError, PowerIterationFailedConvergence, ValueError):
            logger.exception(
                "Centrality computation failed; falling back to degree centrality"
            )
            if self.graph.nodes():
                fallback = nx.degree_centrality(self.graph)
                self.authority_scores = fallback
                self.hub_scores = fallback
                self.pagerank_scores = fallback

171 

172 

@dataclass
class ComplementaryContent:
    """Complementary content recommendations for a target document."""

    target_doc_id: str
    recommendations: list[tuple[str, float, str]] = field(
        default_factory=list
    )  # (doc_id, score, reason)
    recommendation_strategy: str = "mixed"
    generated_at: datetime = field(default_factory=datetime.now)

    def get_top_recommendations(self, limit: int = 5) -> list[dict[str, Any]]:
        """Return up to ``limit`` recommendations, highest score first.

        Args:
            limit: Maximum number of recommendations; must be a positive int.

        Raises:
            ValueError: If ``limit`` is not an int or is not positive.
        """
        # Validate input limit explicitly to avoid silent misuse.
        if not isinstance(limit, int) or limit <= 0:
            raise ValueError("limit must be an int greater than 0")

        ranked = sorted(self.recommendations, key=lambda rec: rec[1], reverse=True)
        results: list[dict[str, Any]] = []
        for doc_id, score, reason in ranked[:limit]:
            results.append(
                {
                    "document_id": doc_id,
                    "relevance_score": score,
                    "recommendation_reason": reason,
                    "strategy": self.recommendation_strategy,
                }
            )
        return results

202 

203 

@dataclass
class ConflictAnalysis:
    """Analysis of conflicting information between documents."""

    conflicting_pairs: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (doc1, doc2, conflict_info)
    conflict_categories: dict[str, list[tuple[str, str]]] = field(default_factory=dict)
    resolution_suggestions: dict[str, str] = field(default_factory=dict)

    def get_conflict_summary(self) -> dict[str, Any]:
        """Return an aggregate view of detected conflicts.

        Includes total pair count, per-category pair counts, the most
        common conflict categories, and up to three resolution suggestions.
        """
        category_counts = {
            category: len(pairs)
            for category, pairs in self.conflict_categories.items()
        }
        suggestions = list(self.resolution_suggestions.values())
        return {
            "total_conflicts": len(self.conflicting_pairs),
            "conflict_categories": category_counts,
            "most_common_conflicts": self._get_most_common_conflicts(),
            "resolution_suggestions": suggestions[:3],
        }

    def _get_most_common_conflicts(self) -> list[str]:
        """Return up to three category names, largest category first."""
        by_size = sorted(
            self.conflict_categories,
            key=lambda category: len(self.conflict_categories[category]),
            reverse=True,
        )
        return by_size[:3]