Coverage for src / qdrant_loader_mcp_server / search / enhanced / cdi / models.py: 99%
126 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 04:51 +0000
1from __future__ import annotations
3from dataclasses import dataclass, field
4from datetime import datetime
5from enum import Enum
6from typing import Any
8# Note: networkx is imported at module level because:
9# 1. This module (cdi/models) is only imported when CDI functionality is needed
10# 2. Tests need to be able to patch nx.hits, nx.pagerank, etc.
11# The lazy loading happens at a higher level - this module is not imported at MCP startup
12import networkx as nx
13from networkx.exception import NetworkXError, PowerIterationFailedConvergence
15from ....utils.logging import LoggingConfig
17logger = LoggingConfig.get_logger(__name__)
class SimilarityMetric(Enum):
    """Types of similarity metrics for document comparison."""

    ENTITY_OVERLAP = "entity_overlap"  # shared named entities between two documents
    TOPIC_OVERLAP = "topic_overlap"  # shared topics between two documents
    SEMANTIC_SIMILARITY = "semantic_similarity"  # embedding/meaning-based similarity
    METADATA_SIMILARITY = "metadata_similarity"  # similarity of document metadata fields
    HIERARCHICAL_DISTANCE = "hierarchical_distance"  # closeness in a document hierarchy
    CONTENT_FEATURES = "content_features"  # similarity of extracted content features
    COMBINED = "combined"  # aggregate of multiple metrics
class ClusteringStrategy(Enum):
    """Strategies for document clustering."""

    ENTITY_BASED = "entity_based"  # group by shared entities
    TOPIC_BASED = "topic_based"  # group by shared topics
    PROJECT_BASED = "project_based"  # group by originating project
    HIERARCHICAL = "hierarchical"  # group by document hierarchy
    MIXED_FEATURES = "mixed_features"  # combine several feature types
    SEMANTIC_EMBEDDING = "semantic_embedding"  # group by embedding similarity
class RelationshipType(Enum):
    """Types of relationships between documents."""

    HIERARCHICAL = "hierarchical"  # parent/child or ancestor relationship
    CROSS_REFERENCE = "cross_reference"  # one document references the other
    SEMANTIC_SIMILARITY = "semantic_similarity"  # related by meaning
    COMPLEMENTARY = "complementary"  # documents that complete each other
    CONFLICTING = "conflicting"  # documents with contradictory information
    SEQUENTIAL = "sequential"  # ordered sequence (e.g. parts of a series)
    TOPICAL_GROUPING = "topical_grouping"  # related by shared topic
    PROJECT_GROUPING = "project_grouping"  # related by shared project
@dataclass
class DocumentSimilarity:
    """Similarity record for a pair of documents.

    Holds an overall score plus per-metric scores and the shared
    entities/topics that justify the relationship.
    """

    doc1_id: str
    doc2_id: str
    similarity_score: float  # overall similarity in [0.0, 1.0]
    metric_scores: dict[SimilarityMetric, float] = field(default_factory=dict)
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    relationship_type: RelationshipType = RelationshipType.SEMANTIC_SIMILARITY
    explanation: str = ""

    def get_display_explanation(self) -> str:
        """Get human-readable explanation of similarity."""
        # A precomputed explanation always takes precedence.
        if self.explanation:
            return self.explanation

        parts: list[str] = []
        if self.shared_entities:
            parts.append(f"Shared entities: {', '.join(self.shared_entities[:3])}")
        if self.shared_topics:
            parts.append(f"Shared topics: {', '.join(self.shared_topics[:3])}")
        if self.metric_scores:
            # Surface only the strongest contributing metric.
            metric, score = max(self.metric_scores.items(), key=lambda item: item[1])
            parts.append(f"High {metric.value}: {score:.2f}")

        return "; ".join(parts) if parts else "Semantic similarity"
@dataclass
class DocumentCluster:
    """A group of related documents with shared features.

    Stores member document IDs plus the entities/topics that bind the
    cluster together and a coherence score for ranking clusters.
    """

    cluster_id: str
    name: str
    documents: list[str] = field(default_factory=list)  # Document IDs
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    cluster_strategy: ClusteringStrategy = ClusteringStrategy.MIXED_FEATURES
    coherence_score: float = 0.0  # cluster cohesion in [0.0, 1.0]
    representative_doc_id: str = ""
    cluster_description: str = ""

    def get_cluster_summary(self) -> dict[str, Any]:
        """Get summary information about the cluster."""
        # Cap entity/topic lists so the summary stays compact.
        summary: dict[str, Any] = {
            "cluster_id": self.cluster_id,
            "name": self.name,
            "document_count": len(self.documents),
            "coherence_score": self.coherence_score,
            "primary_entities": self.shared_entities[:5],
            "primary_topics": self.shared_topics[:5],
            "strategy": self.cluster_strategy.value,
            "description": self.cluster_description,
        }
        return summary
@dataclass
class CitationNetwork:
    """Represents a citation/reference network between documents.

    Wraps a directed NetworkX graph built from ``nodes``/``edges`` and
    caches authority, hub, and PageRank scores for its documents.
    """

    nodes: dict[str, dict[str, Any]] = field(default_factory=dict)  # doc_id -> metadata
    edges: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (from, to, metadata)
    graph: nx.DiGraph | None = None
    authority_scores: dict[str, float] = field(default_factory=dict)
    hub_scores: dict[str, float] = field(default_factory=dict)
    pagerank_scores: dict[str, float] = field(default_factory=dict)

    def build_graph(self) -> nx.DiGraph:
        """Build NetworkX graph from nodes and edges."""
        # Create the graph lazily; re-adding existing nodes/edges below is
        # harmless because NetworkX merges attributes on duplicates.
        if self.graph is None:
            self.graph = nx.DiGraph()

        graph = self.graph
        for doc_id, metadata in self.nodes.items():
            graph.add_node(doc_id, **metadata)
        for src_doc, dst_doc, edge_metadata in self.edges:
            graph.add_edge(src_doc, dst_doc, **edge_metadata)

        return graph

    def calculate_centrality_scores(self):
        """Calculate various centrality scores for the citation network."""
        if self.graph is None:
            self.build_graph()

        def _use_degree_centrality() -> None:
            # Degree centrality is always well-defined, so it serves as the
            # fallback for both edgeless graphs and failed computations.
            if self.graph.nodes():
                degree_scores = nx.degree_centrality(self.graph)
                self.authority_scores = degree_scores
                self.hub_scores = degree_scores
                self.pagerank_scores = degree_scores

        try:
            # HITS/PageRank need edges; without them fall back immediately.
            if self.graph.number_of_edges() == 0:
                _use_degree_centrality()
                return

            hubs, authorities = nx.hits(self.graph, max_iter=100, normalized=True)
            self.hub_scores = hubs
            self.authority_scores = authorities

            self.pagerank_scores = nx.pagerank(self.graph, max_iter=100)
        except (NetworkXError, PowerIterationFailedConvergence, ValueError):
            logger.exception(
                "Centrality computation failed; falling back to degree centrality"
            )
            _use_degree_centrality()
@dataclass
class ComplementaryContent:
    """Complementary content recommendations for one target document.

    Each recommendation is a ``(doc_id, score, reason)`` triple; higher
    scores indicate stronger relevance to the target document.
    """

    target_doc_id: str
    recommendations: list[tuple[str, float, str]] = field(
        default_factory=list
    )  # (doc_id, score, reason)
    recommendation_strategy: str = "mixed"
    generated_at: datetime = field(default_factory=datetime.now)

    def get_top_recommendations(self, limit: int = 5) -> list[dict[str, Any]]:
        """Get top N recommendations with detailed information.

        Raises:
            ValueError: if ``limit`` is not a positive integer.
        """
        # Reject bad limits explicitly instead of silently slicing.
        if not isinstance(limit, int) or limit <= 0:
            raise ValueError("limit must be an int greater than 0")

        ranked = sorted(self.recommendations, key=lambda rec: rec[1], reverse=True)
        results: list[dict[str, Any]] = []
        for doc_id, score, reason in ranked[:limit]:
            results.append(
                {
                    "document_id": doc_id,
                    "relevance_score": score,
                    "recommendation_reason": reason,
                    "strategy": self.recommendation_strategy,
                }
            )
        return results
@dataclass
class ConflictAnalysis:
    """Analysis of conflicting information between documents.

    Tracks conflicting document pairs, groups them into categories, and
    maps categories to suggested resolutions.
    """

    conflicting_pairs: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (doc1, doc2, conflict_info)
    conflict_categories: dict[str, list[tuple[str, str]]] = field(default_factory=dict)
    resolution_suggestions: dict[str, str] = field(default_factory=dict)

    def get_conflict_summary(self) -> dict[str, Any]:
        """Get summary of detected conflicts."""
        category_counts = {
            category: len(pairs)
            for category, pairs in self.conflict_categories.items()
        }
        # Keep the summary compact: at most three suggestions.
        top_suggestions = list(self.resolution_suggestions.values())[:3]
        return {
            "total_conflicts": len(self.conflicting_pairs),
            "conflict_categories": category_counts,
            "most_common_conflicts": self._get_most_common_conflicts(),
            "resolution_suggestions": top_suggestions,
        }

    def _get_most_common_conflicts(self) -> list[str]:
        """Get the most common types of conflicts."""
        # Rank categories by how many conflicting pairs they contain.
        ranked = sorted(
            self.conflict_categories,
            key=lambda category: len(self.conflict_categories[category]),
            reverse=True,
        )
        return ranked[:3]