Coverage for src/qdrant_loader_mcp_server/search/enhanced/cdi/conflict_pairing.py: 58%

101 statements  


from __future__ import annotations

import asyncio
from typing import Any

import numpy as np


async def get_document_embeddings(
    detector: Any, document_ids: list[str]
) -> dict[str, list[float]]:
    """Retrieve document embeddings from Qdrant using detector settings.

    This function mirrors ConflictDetector._get_document_embeddings and is
    extracted to reduce module size. It expects a detector with attributes:
    - qdrant_client, collection_name, preferred_vector_name, logger, _settings (optional)
    """
    if not getattr(detector, "qdrant_client", None):
        return {}

    qdrant_client = detector.qdrant_client

    # Support mocked clients in tests: MagicMock methods expose `_mock_name`,
    # so route those through a simple sequential `retrieve` loop.
    if hasattr(qdrant_client, "retrieve") and hasattr(
        qdrant_client.retrieve, "_mock_name"
    ):
        embeddings: dict[str, list[float]] = {}
        for doc_id in document_ids:
            try:
                points = await qdrant_client.retrieve(
                    collection_name=detector.collection_name,
                    ids=[doc_id],
                    with_vectors=True,
                )
                if points:
                    point = points[0]
                    if hasattr(point, "vector") and point.vector:
                        embeddings[doc_id] = point.vector
            except Exception as e:  # pragma: no cover (best-effort logging)
                detector.logger.warning(
                    f"Failed to retrieve embedding for {doc_id}: {e}"
                )
        return embeddings

    try:
        embeddings: dict[str, list[float]] = {}
        # Optional tuning knobs; fall back to defaults when absent or None.
        settings = getattr(detector, "_settings", {}) or {}
        timeout_s = settings.get("conflict_embeddings_timeout_s", 5.0)
        max_cc = settings.get("conflict_embeddings_max_concurrency", 5)

        # Bound concurrent Qdrant requests so large batches stay well-behaved.
        semaphore = asyncio.Semaphore(max_cc)

        async def fetch_embedding(doc_id: str) -> None:
            async with semaphore:
                try:
                    search_result = await asyncio.wait_for(
                        qdrant_client.scroll(
                            collection_name=detector.collection_name,
                            scroll_filter={
                                "must": [
                                    {
                                        "key": "document_id",
                                        "match": {"value": doc_id},
                                    }
                                ]
                            },
                            limit=1,
                            with_vectors=True,
                        ),
                        timeout=timeout_s,
                    )

                    # scroll() returns (points, next_page_offset).
                    if search_result and search_result[0]:
                        point = search_result[0][0]
                        if point.vector:
                            if isinstance(point.vector, dict):
                                # Named vectors: prefer the configured name,
                                # then "dense", then any available vector.
                                vector_data = (
                                    point.vector.get(detector.preferred_vector_name)
                                    or point.vector.get("dense")
                                    or next(iter(point.vector.values()), None)
                                )
                            else:
                                vector_data = point.vector

                            if vector_data:
                                embeddings[doc_id] = vector_data
                            else:
                                detector.logger.warning(
                                    f"No vector data found for document {doc_id}"
                                )
                        else:
                            detector.logger.warning(
                                f"No vectors found for document {doc_id}"
                            )
                except (TimeoutError, asyncio.TimeoutError):
                    # asyncio.wait_for raises asyncio.TimeoutError, an alias of
                    # TimeoutError only on Python 3.11+; catch both so older
                    # runtimes do not fall through to the generic handler.
                    detector.logger.warning(
                        f"Timeout retrieving embedding for document {doc_id}"
                    )
                except Exception as e:  # pragma: no cover
                    detector.logger.error(
                        f"Error retrieving embedding for document {doc_id}: {e}"
                    )

        await asyncio.gather(
            *(fetch_embedding(doc_id) for doc_id in document_ids),
            return_exceptions=True,
        )
        return embeddings
    except Exception as e:  # pragma: no cover
        detector.logger.error(f"Error retrieving document embeddings: {e}")
        return {}
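

# Usage sketch (illustrative, not part of the original module): how the helper
# can be driven with a hypothetical detector stub. The attribute names and
# settings keys mirror what get_document_embeddings reads; the stub itself,
# the URL, and the collection name are assumptions for this example.
async def _example_get_document_embeddings() -> None:
    import logging
    from types import SimpleNamespace

    from qdrant_client import AsyncQdrantClient

    detector = SimpleNamespace(
        qdrant_client=AsyncQdrantClient(url="http://localhost:6333"),
        collection_name="documents",
        preferred_vector_name="dense",
        logger=logging.getLogger(__name__),
        _settings={
            "conflict_embeddings_timeout_s": 5.0,
            "conflict_embeddings_max_concurrency": 5,
        },
    )
    embeddings = await get_document_embeddings(detector, ["doc-1", "doc-2"])
    print(f"fetched {len(embeddings)} embeddings")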



def calculate_vector_similarity(
    _detector: Any, embedding1: list[float], embedding2: list[float]
) -> float:
    """Cosine similarity with clipping to [-1, 1]."""
    try:
        vec1 = np.array(embedding1)
        vec2 = np.array(embedding2)

        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        similarity = dot_product / (norm1 * norm2)
        # Floating-point error can push the ratio slightly outside [-1, 1]; clip it.
        return float(np.clip(similarity, -1.0, 1.0))
    except Exception:  # pragma: no cover
        return 0.0
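

# Worked example (illustrative): orthogonal vectors score 0.0, vectors with the
# same direction score 1.0, and opposite directions score -1.0. The _detector
# argument is unused by the function, so None suffices here.
def _example_vector_similarity() -> None:
    assert calculate_vector_similarity(None, [1.0, 0.0], [0.0, 1.0]) == 0.0
    assert calculate_vector_similarity(None, [1.0, 0.0], [2.0, 0.0]) == 1.0
    assert calculate_vector_similarity(None, [1.0, 0.0], [-3.0, 0.0]) == -1.0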



async def filter_by_vector_similarity(
    detector: Any, documents: list[Any]
) -> list[tuple]:
    """Filter document pairs whose vector similarity falls within the configured band."""
    similar_pairs: list[tuple] = []
    if len(documents) < 2:
        return similar_pairs

    # Prefer an explicit document_id; build the fallback lazily so documents
    # that do have an ID are not required to carry source_type/source_title.
    document_ids = [
        getattr(doc, "document_id", None) or f"{doc.source_type}:{doc.source_title}"
        for doc in documents
    ]
    embeddings = await get_document_embeddings(detector, document_ids)

    for i, doc1 in enumerate(documents):
        for j, doc2 in enumerate(documents[i + 1 :], i + 1):
            doc1_id = document_ids[i]
            doc2_id = document_ids[j]
            similarity_score = 0.0
            if doc1_id in embeddings and doc2_id in embeddings:
                similarity_score = calculate_vector_similarity(
                    detector, embeddings[doc1_id], embeddings[doc2_id]
                )
            # Keep only pairs inside [MIN, MAX]: a band (rather than a single
            # threshold) typically targets related-but-not-identical documents.
            if (
                detector.MIN_VECTOR_SIMILARITY
                <= similarity_score
                <= detector.MAX_VECTOR_SIMILARITY
            ):
                similar_pairs.append((doc1, doc2, similarity_score))

    # Most similar pairs first.
    similar_pairs.sort(key=lambda x: x[2], reverse=True)
    return similar_pairs
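

# Usage sketch (illustrative): the detector must expose MIN_VECTOR_SIMILARITY
# and MAX_VECTOR_SIMILARITY; the band values below are placeholders, not the
# real ConflictDetector defaults. Documents are expected to expose document_id
# (or source_type/source_title). With no qdrant_client, every pair scores 0.0
# and falls below this band, so the result is empty; with a live client the
# scores come from the stored vectors.
async def _example_filter_by_similarity(documents: list[Any]) -> list[tuple]:
    from types import SimpleNamespace

    detector = SimpleNamespace(
        qdrant_client=None,  # no client: get_document_embeddings returns {}
        MIN_VECTOR_SIMILARITY=0.35,
        MAX_VECTOR_SIMILARITY=0.95,
    )
    return await filter_by_vector_similarity(detector, documents)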



def should_analyze_for_conflicts(_detector: Any, doc1: Any, doc2: Any) -> bool:
    """Pre-screen a document pair before running full conflict analysis."""
    if not doc1 or not doc2:
        return False
    text1 = getattr(doc1, "text", None) or ""
    text2 = getattr(doc2, "text", None) or ""
    # Skip trivially short documents: too little text to conflict meaningfully.
    if len(text1.strip()) < 10 or len(text2.strip()) < 10:
        return False
    # Skip self-pairs and exact duplicates; identical content cannot conflict.
    if hasattr(doc1, "document_id") and hasattr(doc2, "document_id"):
        if doc1.document_id == doc2.document_id:
            return False
    if text1.strip() == text2.strip():
        return False
    return True
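

# Worked example (illustrative): the pre-screen rejects short texts, self-pairs,
# and exact duplicates. Doc is a hypothetical stand-in for the real result type.
def _example_should_analyze() -> None:
    from dataclasses import dataclass

    @dataclass
    class Doc:
        document_id: str
        text: str

    a = Doc("a", "Deploy with the blue/green strategy for zero downtime.")
    b = Doc("b", "Always deploy in place; blue/green is not supported.")
    assert should_analyze_for_conflicts(None, a, b) is True
    assert should_analyze_for_conflicts(None, a, a) is False  # same document_id
    assert should_analyze_for_conflicts(None, a, Doc("c", "short")) is False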



async def get_tiered_analysis_pairs(detector: Any, documents: list[Any]) -> list[tuple]:
    """Generate tiered analysis pairs for conflict detection (extracted)."""
    pairs: list[tuple] = []
    if len(documents) < 2:
        return pairs

    for i, doc1 in enumerate(documents):
        for doc2 in documents[i + 1 :]:
            # Default to the highest priority when documents carry no scores.
            score = 1.0
            if hasattr(doc1, "score") and hasattr(doc2, "score"):
                avg_doc_score = (doc1.score + doc2.score) / 2
                score = min(1.0, avg_doc_score)

            # Bucket pairs by average retrieval score: primary >= 0.8,
            # secondary >= 0.5, tertiary below.
            if score >= 0.8:
                tier = "primary"
            elif score >= 0.5:
                tier = "secondary"
            else:
                tier = "tertiary"

            pairs.append((doc1, doc2, tier, score))

    # Highest-scoring pairs first.
    pairs.sort(key=lambda x: x[3], reverse=True)
    return pairs
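

# Worked example (illustrative): three documents with retrieval scores yield
# three pairs bucketed by average score. SimpleNamespace stands in for the
# real result type; the detector argument is unused by the tiering logic.
async def _example_tiered_pairs() -> None:
    from types import SimpleNamespace

    docs = [
        SimpleNamespace(score=0.9),
        SimpleNamespace(score=0.8),
        SimpleNamespace(score=0.1),
    ]
    pairs = await get_tiered_analysis_pairs(None, docs)
    # Averages: (0.9 + 0.8) / 2 = 0.85 -> primary; (0.9 + 0.1) / 2 = 0.5 ->
    # secondary; (0.8 + 0.1) / 2 = 0.45 -> tertiary. Compare rounded values to
    # sidestep floating-point noise.
    assert [tier for _, _, tier, _ in pairs] == ["primary", "secondary", "tertiary"]
    assert [round(score, 2) for *_, score in pairs] == [0.85, 0.5, 0.45]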