Coverage for src/qdrant_loader_mcp_server/search/enhanced/cdi/models.py: 99%

126 statements  

coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any

import networkx as nx
from networkx.exception import NetworkXError, PowerIterationFailedConvergence

from ....utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class SimilarityMetric(Enum):
    """Types of similarity metrics for document comparison."""

    ENTITY_OVERLAP = "entity_overlap"
    TOPIC_OVERLAP = "topic_overlap"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    METADATA_SIMILARITY = "metadata_similarity"
    HIERARCHICAL_DISTANCE = "hierarchical_distance"
    CONTENT_FEATURES = "content_features"
    COMBINED = "combined"


class ClusteringStrategy(Enum):
    """Strategies for document clustering."""

    ENTITY_BASED = "entity_based"
    TOPIC_BASED = "topic_based"
    PROJECT_BASED = "project_based"
    HIERARCHICAL = "hierarchical"
    MIXED_FEATURES = "mixed_features"
    SEMANTIC_EMBEDDING = "semantic_embedding"


class RelationshipType(Enum):
    """Types of relationships between documents."""

    HIERARCHICAL = "hierarchical"
    CROSS_REFERENCE = "cross_reference"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    COMPLEMENTARY = "complementary"
    CONFLICTING = "conflicting"
    SEQUENTIAL = "sequential"
    TOPICAL_GROUPING = "topical_grouping"
    PROJECT_GROUPING = "project_grouping"
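
The string values above are what other components store and serialize. A quick sketch of round-tripping members by value, using the module import path inferred from the file path in the header:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import (
        ClusteringStrategy,
        SimilarityMetric,
    )

    # Members can be reconstructed from their stored string values, which is
    # how they survive (de)serialization in metadata payloads.
    assert SimilarityMetric("entity_overlap") is SimilarityMetric.ENTITY_OVERLAP
    assert ClusteringStrategy.MIXED_FEATURES.value == "mixed_features"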

@dataclass
class DocumentSimilarity:
    """Represents similarity between two documents."""

    doc1_id: str
    doc2_id: str
    similarity_score: float  # 0.0 - 1.0
    metric_scores: dict[SimilarityMetric, float] = field(default_factory=dict)
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    relationship_type: RelationshipType = RelationshipType.SEMANTIC_SIMILARITY
    explanation: str = ""

    def get_display_explanation(self) -> str:
        """Get human-readable explanation of similarity."""
        if self.explanation:
            return self.explanation

        explanations: list[str] = []
        if self.shared_entities:
            explanations.append(
                f"Shared entities: {', '.join(self.shared_entities[:3])}"
            )
        if self.shared_topics:
            explanations.append(f"Shared topics: {', '.join(self.shared_topics[:3])}")
        if self.metric_scores:
            top_metric = max(self.metric_scores.items(), key=lambda x: x[1])
            explanations.append(f"High {top_metric[0].value}: {top_metric[1]:.2f}")

        return "; ".join(explanations) if explanations else "Semantic similarity"
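
A minimal usage sketch for DocumentSimilarity; the document IDs, entities, and scores are invented for illustration:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import (
        DocumentSimilarity,
        SimilarityMetric,
    )

    sim = DocumentSimilarity(
        doc1_id="doc-a",
        doc2_id="doc-b",
        similarity_score=0.82,
        metric_scores={SimilarityMetric.ENTITY_OVERLAP: 0.9},
        shared_entities=["OAuth", "JWT"],
    )
    # No explicit explanation is set, so the display text is derived from
    # the shared entities and the highest-scoring metric:
    print(sim.get_display_explanation())
    # Shared entities: OAuth, JWT; High entity_overlap: 0.90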

@dataclass
class DocumentCluster:
    """Represents a cluster of related documents."""

    cluster_id: str
    name: str
    documents: list[str] = field(default_factory=list)  # Document IDs
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    cluster_strategy: ClusteringStrategy = ClusteringStrategy.MIXED_FEATURES
    coherence_score: float = 0.0  # 0.0 - 1.0
    representative_doc_id: str = ""
    cluster_description: str = ""

    def get_cluster_summary(self) -> dict[str, Any]:
        """Get summary information about the cluster."""
        return {
            "cluster_id": self.cluster_id,
            "name": self.name,
            "document_count": len(self.documents),
            "coherence_score": self.coherence_score,
            "primary_entities": self.shared_entities[:5],
            "primary_topics": self.shared_topics[:5],
            "strategy": self.cluster_strategy.value,
            "description": self.cluster_description,
        }
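
A sketch of building a cluster and reading its summary; all values are hypothetical:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import (
        ClusteringStrategy,
        DocumentCluster,
    )

    cluster = DocumentCluster(
        cluster_id="c1",
        name="Authentication docs",
        documents=["doc-a", "doc-b", "doc-c"],
        shared_topics=["authentication", "oauth"],
        cluster_strategy=ClusteringStrategy.TOPIC_BASED,
        coherence_score=0.74,
    )
    summary = cluster.get_cluster_summary()
    # summary["document_count"] == 3, summary["strategy"] == "topic_based",
    # and "primary_topics" holds at most the first five shared topics.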

@dataclass
class CitationNetwork:
    """Represents a citation/reference network between documents."""

    nodes: dict[str, dict[str, Any]] = field(default_factory=dict)  # doc_id -> metadata
    edges: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (from, to, metadata)
    graph: nx.DiGraph | None = None
    authority_scores: dict[str, float] = field(default_factory=dict)
    hub_scores: dict[str, float] = field(default_factory=dict)
    pagerank_scores: dict[str, float] = field(default_factory=dict)

    def build_graph(self) -> nx.DiGraph:
        """Build NetworkX graph from nodes and edges."""
        if self.graph is None:
            self.graph = nx.DiGraph()

            for doc_id, metadata in self.nodes.items():
                self.graph.add_node(doc_id, **metadata)

            for from_doc, to_doc, edge_metadata in self.edges:
                self.graph.add_edge(from_doc, to_doc, **edge_metadata)

        return self.graph

    def calculate_centrality_scores(self):
        """Calculate various centrality scores for the citation network."""
        if self.graph is None:
            self.build_graph()

        try:
            if self.graph.number_of_edges() == 0:
                # Without edges, HITS and PageRank are undefined; degree
                # centrality is the only meaningful score.
                if self.graph.nodes():
                    degree_centrality = nx.degree_centrality(self.graph)
                    self.authority_scores = degree_centrality
                    self.hub_scores = degree_centrality
                    self.pagerank_scores = degree_centrality
                return

            # nx.hits returns (hub_scores, authority_scores).
            hits_scores = nx.hits(self.graph, max_iter=100, normalized=True)
            self.hub_scores = hits_scores[0]
            self.authority_scores = hits_scores[1]

            self.pagerank_scores = nx.pagerank(self.graph, max_iter=100)

        except (NetworkXError, PowerIterationFailedConvergence, ValueError):
            logger.exception(
                "Centrality computation failed; falling back to degree centrality"
            )
            if self.graph.nodes():
                degree_centrality = nx.degree_centrality(self.graph)
                self.authority_scores = degree_centrality
                self.hub_scores = degree_centrality
                self.pagerank_scores = degree_centrality
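
A sketch of the centrality flow on a two-document network (IDs and metadata hypothetical). HITS assigns authority to cited documents and hub score to citing ones; with no edges, or if the computation fails to converge, the class falls back to degree centrality:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import CitationNetwork

    network = CitationNetwork(
        nodes={"a": {"title": "Intro"}, "b": {"title": "Details"}},
        edges=[("a", "b", {"type": "cross_reference"})],
    )
    network.calculate_centrality_scores()  # builds the graph on first use
    # "b" is cited, so it carries the authority; "a" cites, so it is the hub.
    print(network.authority_scores["b"], network.hub_scores["a"])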

@dataclass
class ComplementaryContent:
    """Represents complementary content recommendations."""

    target_doc_id: str
    recommendations: list[tuple[str, float, str]] = field(
        default_factory=list
    )  # (doc_id, score, reason)
    recommendation_strategy: str = "mixed"
    generated_at: datetime = field(default_factory=datetime.now)

    def get_top_recommendations(self, limit: int = 5) -> list[dict[str, Any]]:
        """Get top N recommendations with detailed information."""
        # Validate input limit explicitly to avoid silent misuse
        if not isinstance(limit, int) or limit <= 0:
            raise ValueError("limit must be an int greater than 0")

        top_recs = sorted(self.recommendations, key=lambda x: x[1], reverse=True)[
            :limit
        ]
        return [
            {
                "document_id": doc_id,
                "relevance_score": score,
                "recommendation_reason": reason,
                "strategy": self.recommendation_strategy,
            }
            for doc_id, score, reason in top_recs
        ]
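
A sketch of ranking recommendations; the tuples are invented for illustration, and a non-positive or non-integer limit raises ValueError:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import ComplementaryContent

    recs = ComplementaryContent(
        target_doc_id="doc-a",
        recommendations=[
            ("doc-b", 0.9, "covers the same API"),
            ("doc-c", 0.4, "related project"),
        ],
    )
    top = recs.get_top_recommendations(limit=1)
    # [{"document_id": "doc-b", "relevance_score": 0.9,
    #   "recommendation_reason": "covers the same API", "strategy": "mixed"}]
    # recs.get_top_recommendations(limit=0) would raise ValueError.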

@dataclass
class ConflictAnalysis:
    """Represents analysis of conflicting information between documents."""

    conflicting_pairs: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (doc1, doc2, conflict_info)
    conflict_categories: dict[str, list[tuple[str, str]]] = field(default_factory=dict)
    resolution_suggestions: dict[str, str] = field(default_factory=dict)

    def get_conflict_summary(self) -> dict[str, Any]:
        """Get summary of detected conflicts."""
        return {
            "total_conflicts": len(self.conflicting_pairs),
            "conflict_categories": {
                cat: len(pairs) for cat, pairs in self.conflict_categories.items()
            },
            "most_common_conflicts": self._get_most_common_conflicts(),
            "resolution_suggestions": list(self.resolution_suggestions.values())[:3],
        }

    def _get_most_common_conflicts(self) -> list[str]:
        """Get the most common types of conflicts."""
        return sorted(
            self.conflict_categories.keys(),
            key=lambda x: len(self.conflict_categories[x]),
            reverse=True,
        )[:3]
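
Finally, a sketch of summarizing detected conflicts; the pair and category data are hypothetical:

    from qdrant_loader_mcp_server.search.enhanced.cdi.models import ConflictAnalysis

    analysis = ConflictAnalysis(
        conflicting_pairs=[("doc-a", "doc-b", {"reason": "version mismatch"})],
        conflict_categories={"versioning": [("doc-a", "doc-b")]},
        resolution_suggestions={"versioning": "Prefer the newer document"},
    )
    summary = analysis.get_conflict_summary()
    # summary["total_conflicts"] == 1; categories are ranked by pair count,
    # and only the top three appear in "most_common_conflicts".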