Coverage for src/qdrant_loader_mcp_server/search/enhanced/cdi/models.py: 99%
126 statements
coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any

import networkx as nx
from networkx.exception import NetworkXError, PowerIterationFailedConvergence

from ....utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class SimilarityMetric(Enum):
    """Types of similarity metrics for document comparison."""

    ENTITY_OVERLAP = "entity_overlap"
    TOPIC_OVERLAP = "topic_overlap"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    METADATA_SIMILARITY = "metadata_similarity"
    HIERARCHICAL_DISTANCE = "hierarchical_distance"
    CONTENT_FEATURES = "content_features"
    COMBINED = "combined"


class ClusteringStrategy(Enum):
    """Strategies for document clustering."""

    ENTITY_BASED = "entity_based"
    TOPIC_BASED = "topic_based"
    PROJECT_BASED = "project_based"
    HIERARCHICAL = "hierarchical"
    MIXED_FEATURES = "mixed_features"
    SEMANTIC_EMBEDDING = "semantic_embedding"


class RelationshipType(Enum):
    """Types of relationships between documents."""

    HIERARCHICAL = "hierarchical"
    CROSS_REFERENCE = "cross_reference"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    COMPLEMENTARY = "complementary"
    CONFLICTING = "conflicting"
    SEQUENTIAL = "sequential"
    TOPICAL_GROUPING = "topical_grouping"
    PROJECT_GROUPING = "project_grouping"
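
# Illustrative note (not part of the original module): each enum member
# carries a stable string value, so values read back from stored payloads
# can be round-tripped by value rather than by name, e.g.:
#
#     RelationshipType("cross_reference") is RelationshipType.CROSS_REFERENCE
#     SimilarityMetric.COMBINED.value == "combined"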


@dataclass
class DocumentSimilarity:
    """Represents similarity between two documents."""

    doc1_id: str
    doc2_id: str
    similarity_score: float  # 0.0 - 1.0
    metric_scores: dict[SimilarityMetric, float] = field(default_factory=dict)
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    relationship_type: RelationshipType = RelationshipType.SEMANTIC_SIMILARITY
    explanation: str = ""

    def get_display_explanation(self) -> str:
        """Get human-readable explanation of similarity."""
        if self.explanation:
            return self.explanation

        explanations: list[str] = []
        if self.shared_entities:
            explanations.append(
                f"Shared entities: {', '.join(self.shared_entities[:3])}"
            )
        if self.shared_topics:
            explanations.append(f"Shared topics: {', '.join(self.shared_topics[:3])}")
        if self.metric_scores:
            top_metric = max(self.metric_scores.items(), key=lambda x: x[1])
            explanations.append(f"High {top_metric[0].value}: {top_metric[1]:.2f}")

        return "; ".join(explanations) if explanations else "Semantic similarity"
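
# Usage sketch (illustrative only; the IDs, entities, and scores below are
# hypothetical):
#
#     sim = DocumentSimilarity(
#         doc1_id="doc-a",
#         doc2_id="doc-b",
#         similarity_score=0.82,
#         shared_entities=["qdrant", "hybrid search"],
#         metric_scores={SimilarityMetric.ENTITY_OVERLAP: 0.90},
#     )
#     sim.get_display_explanation()
#     # -> "Shared entities: qdrant, hybrid search; High entity_overlap: 0.90"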


@dataclass
class DocumentCluster:
    """Represents a cluster of related documents."""

    cluster_id: str
    name: str
    documents: list[str] = field(default_factory=list)  # Document IDs
    shared_entities: list[str] = field(default_factory=list)
    shared_topics: list[str] = field(default_factory=list)
    cluster_strategy: ClusteringStrategy = ClusteringStrategy.MIXED_FEATURES
    coherence_score: float = 0.0  # 0.0 - 1.0
    representative_doc_id: str = ""
    cluster_description: str = ""

    def get_cluster_summary(self) -> dict[str, Any]:
        """Get summary information about the cluster."""
        return {
            "cluster_id": self.cluster_id,
            "name": self.name,
            "document_count": len(self.documents),
            "coherence_score": self.coherence_score,
            "primary_entities": self.shared_entities[:5],
            "primary_topics": self.shared_topics[:5],
            "strategy": self.cluster_strategy.value,
            "description": self.cluster_description,
        }
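
# Usage sketch (illustrative only; the cluster contents are hypothetical):
#
#     cluster = DocumentCluster(
#         cluster_id="c1",
#         name="Vector search docs",
#         documents=["doc-a", "doc-b", "doc-c"],
#         shared_topics=["vector search"],
#         coherence_score=0.74,
#     )
#     cluster.get_cluster_summary()["document_count"]  # -> 3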


@dataclass
class CitationNetwork:
    """Represents a citation/reference network between documents."""

    nodes: dict[str, dict[str, Any]] = field(default_factory=dict)  # doc_id -> metadata
    edges: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (from, to, metadata)
    graph: nx.DiGraph | None = None
    authority_scores: dict[str, float] = field(default_factory=dict)
    hub_scores: dict[str, float] = field(default_factory=dict)
    pagerank_scores: dict[str, float] = field(default_factory=dict)

    def build_graph(self) -> nx.DiGraph:
        """Build NetworkX graph from nodes and edges."""
        if self.graph is None:
            self.graph = nx.DiGraph()

            for doc_id, metadata in self.nodes.items():
                self.graph.add_node(doc_id, **metadata)

            for from_doc, to_doc, edge_metadata in self.edges:
                self.graph.add_edge(from_doc, to_doc, **edge_metadata)

        return self.graph

    def calculate_centrality_scores(self):
        """Calculate various centrality scores for the citation network."""
        if self.graph is None:
            self.build_graph()

        try:
            if self.graph.number_of_edges() == 0:
                # HITS and PageRank are uninformative on an edgeless graph;
                # fall back to degree centrality for all three score sets.
                if self.graph.nodes():
                    degree_centrality = nx.degree_centrality(self.graph)
                    self.authority_scores = degree_centrality
                    self.hub_scores = degree_centrality
                    self.pagerank_scores = degree_centrality
                return

            # nx.hits returns (hubs, authorities), in that order.
            hits_scores = nx.hits(self.graph, max_iter=100, normalized=True)
            self.hub_scores = hits_scores[0]
            self.authority_scores = hits_scores[1]

            self.pagerank_scores = nx.pagerank(self.graph, max_iter=100)

        except (NetworkXError, PowerIterationFailedConvergence, ValueError):
            logger.exception(
                "Centrality computation failed; falling back to degree centrality"
            )
            if self.graph.nodes():
                degree_centrality = nx.degree_centrality(self.graph)
                self.authority_scores = degree_centrality
                self.hub_scores = degree_centrality
                self.pagerank_scores = degree_centrality
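
# Usage sketch (illustrative only; the two-document network below is
# hypothetical):
#
#     network = CitationNetwork(
#         nodes={"doc-a": {"title": "A"}, "doc-b": {"title": "B"}},
#         edges=[("doc-a", "doc-b", {"type": "cross_reference"})],
#     )
#     network.calculate_centrality_scores()
#     network.pagerank_scores  # doc-b ranks higher: it is the cited document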


@dataclass
class ComplementaryContent:
    """Represents complementary content recommendations."""

    target_doc_id: str
    recommendations: list[tuple[str, float, str]] = field(
        default_factory=list
    )  # (doc_id, score, reason)
    recommendation_strategy: str = "mixed"
    generated_at: datetime = field(default_factory=datetime.now)

    def get_top_recommendations(self, limit: int = 5) -> list[dict[str, Any]]:
        """Get top N recommendations with detailed information."""
        # Validate input limit explicitly to avoid silent misuse
        if not isinstance(limit, int) or limit <= 0:
            raise ValueError("limit must be an int greater than 0")

        top_recs = sorted(self.recommendations, key=lambda x: x[1], reverse=True)[
            :limit
        ]
        return [
            {
                "document_id": doc_id,
                "relevance_score": score,
                "recommendation_reason": reason,
                "strategy": self.recommendation_strategy,
            }
            for doc_id, score, reason in top_recs
        ]
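
# Usage sketch (illustrative only; the recommendation tuples are hypothetical):
#
#     content = ComplementaryContent(
#         target_doc_id="doc-a",
#         recommendations=[
#             ("doc-b", 0.9, "covers the same entities"),
#             ("doc-c", 0.4, "related project"),
#         ],
#     )
#     content.get_top_recommendations(limit=1)[0]["document_id"]  # -> "doc-b"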


@dataclass
class ConflictAnalysis:
    """Represents analysis of conflicting information between documents."""

    conflicting_pairs: list[tuple[str, str, dict[str, Any]]] = field(
        default_factory=list
    )  # (doc1, doc2, conflict_info)
    conflict_categories: dict[str, list[tuple[str, str]]] = field(default_factory=dict)
    resolution_suggestions: dict[str, str] = field(default_factory=dict)

    def get_conflict_summary(self) -> dict[str, Any]:
        """Get summary of detected conflicts."""
        return {
            "total_conflicts": len(self.conflicting_pairs),
            "conflict_categories": {
                cat: len(pairs) for cat, pairs in self.conflict_categories.items()
            },
            "most_common_conflicts": self._get_most_common_conflicts(),
            "resolution_suggestions": list(self.resolution_suggestions.values())[:3],
        }

    def _get_most_common_conflicts(self) -> list[str]:
        """Get the most common types of conflicts."""
        return sorted(
            self.conflict_categories.keys(),
            key=lambda x: len(self.conflict_categories[x]),
            reverse=True,
        )[:3]
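
# Usage sketch (illustrative only; the conflict data is hypothetical):
#
#     analysis = ConflictAnalysis(
#         conflicting_pairs=[("doc-a", "doc-b", {"category": "version"})],
#         conflict_categories={"version": [("doc-a", "doc-b")]},
#     )
#     analysis.get_conflict_summary()["total_conflicts"]  # -> 1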