Coverage for src/qdrant_loader_mcp_server/search/enhanced/cdi/conflict_pairing.py: 58% (101 statements)
from __future__ import annotations

import asyncio
from typing import Any

import numpy as np

async def get_document_embeddings(
    detector: Any, document_ids: list[str]
) -> dict[str, list[float]]:
    """Retrieve document embeddings from Qdrant using detector settings.

    This function mirrors ConflictDetector._get_document_embeddings and is extracted
    to reduce module size. It expects a detector with attributes:
    - qdrant_client, collection_name, preferred_vector_name, logger, _settings (optional)
    """
    if not getattr(detector, "qdrant_client", None):
        return {}

    qdrant_client = detector.qdrant_client

    # Support mocked client in tests
    if hasattr(qdrant_client, "retrieve") and hasattr(
        qdrant_client.retrieve, "_mock_name"
    ):
        embeddings: dict[str, list[float]] = {}
        for doc_id in document_ids:
            try:
                points = await qdrant_client.retrieve(
                    collection_name=detector.collection_name,
                    ids=[doc_id],
                    with_vectors=True,
                )
                if points:
                    point = points[0]
                    if hasattr(point, "vector") and point.vector:
                        embeddings[doc_id] = point.vector
            except Exception as e:  # pragma: no cover (best-effort logging)
                detector.logger.warning(
                    f"Failed to retrieve embedding for {doc_id}: {e}"
                )
        return embeddings

    try:
        embeddings: dict[str, list[float]] = {}
        settings = (
            getattr(detector, "_settings", {}) if hasattr(detector, "_settings") else {}
        )
        timeout_s = settings.get("conflict_embeddings_timeout_s", 5.0)
        max_cc = settings.get("conflict_embeddings_max_concurrency", 5)

        semaphore = asyncio.Semaphore(max_cc)

        async def fetch_embedding(doc_id: str) -> None:
            async with semaphore:
                try:
                    search_result = await asyncio.wait_for(
                        qdrant_client.scroll(
                            collection_name=detector.collection_name,
                            scroll_filter={
                                "must": [
                                    {
                                        "key": "document_id",
                                        "match": {"value": doc_id},
                                    }
                                ]
                            },
                            limit=1,
                            with_vectors=True,
                        ),
                        timeout=timeout_s,
                    )

                    if search_result and search_result[0]:
                        point = search_result[0][0]
                        if point.vector:
                            if isinstance(point.vector, dict):
                                vector_data = (
                                    point.vector.get(detector.preferred_vector_name)
                                    or point.vector.get("dense")
                                    or next(iter(point.vector.values()), None)
                                )
                            else:
                                vector_data = point.vector

                            if vector_data:
                                embeddings[doc_id] = vector_data
                            else:
                                detector.logger.warning(
                                    f"No vector data found for document {doc_id}"
                                )
                        else:
                            detector.logger.warning(
                                f"No vectors found for document {doc_id}"
                            )
                except TimeoutError:
                    detector.logger.warning(
                        f"Timeout retrieving embedding for document {doc_id}"
                    )
                except Exception as e:  # pragma: no cover
                    detector.logger.error(
                        f"Error retrieving embedding for document {doc_id}: {e}"
                    )

        await asyncio.gather(
            *(fetch_embedding(doc_id) for doc_id in document_ids),
            return_exceptions=True,
        )
        return embeddings
    except Exception as e:  # pragma: no cover
        detector.logger.error(f"Error retrieving document embeddings: {e}")
        return {}
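
# Usage sketch (not part of the module): a hypothetical detector-like stub carrying the
# attributes the docstring above expects. The stub client only imitates the scroll() call
# shape used here; real callers pass a detector backed by an async Qdrant client.
def _example_get_document_embeddings() -> dict[str, list[float]]:
    import logging
    from types import SimpleNamespace

    class _StubClient:
        # Minimal async scroll() returning (points, next_page_offset) as Qdrant does.
        async def scroll(self, collection_name, scroll_filter, limit, with_vectors):
            point = SimpleNamespace(vector={"dense": [0.1, 0.2, 0.3]})
            return ([point], None)

    detector = SimpleNamespace(
        qdrant_client=_StubClient(),
        collection_name="documents",  # hypothetical collection name
        preferred_vector_name="dense",
        logger=logging.getLogger("cdi"),
    )
    # Expected result: {"doc-1": [0.1, 0.2, 0.3]}
    return asyncio.run(get_document_embeddings(detector, ["doc-1"]))
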
def calculate_vector_similarity(
    _detector: Any, embedding1: list[float], embedding2: list[float]
) -> float:
    """Cosine similarity with clipping to [-1, 1]."""
    try:
        vec1 = np.array(embedding1)
        vec2 = np.array(embedding2)

        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        similarity = dot_product / (norm1 * norm2)
        return float(np.clip(similarity, -1.0, 1.0))
    except Exception:  # pragma: no cover
        return 0.0
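
# Worked example (hypothetical vectors): the cosine of the angle between [1, 0] and
# [1, 1] is 1 / sqrt(2), roughly 0.7071; orthogonal vectors score 0.0, and a zero-norm
# input short-circuits to 0.0. The _detector argument is unused by the function.
def _example_vector_similarity() -> tuple[float, float, float]:
    return (
        calculate_vector_similarity(None, [1.0, 0.0], [1.0, 1.0]),  # ~0.7071
        calculate_vector_similarity(None, [1.0, 0.0], [0.0, 1.0]),  # 0.0 (orthogonal)
        calculate_vector_similarity(None, [0.0, 0.0], [1.0, 1.0]),  # 0.0 (zero norm)
    )
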
async def filter_by_vector_similarity(
    detector: Any, documents: list[Any]
) -> list[tuple]:
    """Filter document pairs by vector similarity within configured band."""
    similar_pairs: list[tuple] = []
    if len(documents) < 2:
        return similar_pairs

    document_ids = [
        getattr(doc, "document_id", f"{doc.source_type}:{doc.source_title}")
        for doc in documents
    ]
    embeddings = await get_document_embeddings(detector, document_ids)

    for i, doc1 in enumerate(documents):
        for j, doc2 in enumerate(documents[i + 1 :], i + 1):
            doc1_id = document_ids[i]
            doc2_id = document_ids[j]
            similarity_score = 0.0
            if doc1_id in embeddings and doc2_id in embeddings:
                similarity_score = calculate_vector_similarity(
                    detector, embeddings[doc1_id], embeddings[doc2_id]
                )
            if (
                detector.MIN_VECTOR_SIMILARITY
                <= similarity_score
                <= detector.MAX_VECTOR_SIMILARITY
            ):
                similar_pairs.append((doc1, doc2, similarity_score))

    similar_pairs.sort(key=lambda x: x[2], reverse=True)
    return similar_pairs
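
# Sketch of the band filter (hypothetical stubs and bounds): with no qdrant_client the
# embedding lookup returns {}, every pair keeps the default similarity of 0.0, and a pair
# survives only if 0.0 falls inside [MIN_VECTOR_SIMILARITY, MAX_VECTOR_SIMILARITY].
# The bounds below are illustrative, not the detector's real configuration.
def _example_filter_by_vector_similarity() -> list[tuple]:
    from types import SimpleNamespace

    detector = SimpleNamespace(
        qdrant_client=None,
        MIN_VECTOR_SIMILARITY=0.0,   # assumed lower bound
        MAX_VECTOR_SIMILARITY=0.95,  # assumed upper bound
    )
    docs = [
        SimpleNamespace(document_id="doc-a", source_type="git", source_title="A"),
        SimpleNamespace(document_id="doc-b", source_type="git", source_title="B"),
    ]
    # Expected result: [(docs[0], docs[1], 0.0)]
    return asyncio.run(filter_by_vector_similarity(detector, docs))
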
def should_analyze_for_conflicts(_detector: Any, doc1: Any, doc2: Any) -> bool:
    """Pre-screen documents before conflict analysis."""
    if not doc1 or not doc2:
        return False
    text1 = doc1.text if getattr(doc1, "text", None) else ""
    text2 = doc2.text if getattr(doc2, "text", None) else ""
    if len(text1.strip()) < 10 or len(text2.strip()) < 10:
        return False
    if hasattr(doc1, "document_id") and hasattr(doc2, "document_id"):
        if doc1.document_id == doc2.document_id:
            return False
    if text1.strip() == text2.strip():
        return False
    return True
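
# Pre-screen examples (hypothetical docs): a pair is skipped when either text is shorter
# than 10 characters, when both share a document_id, or when the stripped texts match.
def _example_should_analyze() -> tuple[bool, bool]:
    from types import SimpleNamespace

    long_a = SimpleNamespace(document_id="a", text="Deploy with the blue/green strategy.")
    long_b = SimpleNamespace(document_id="b", text="Deploy with a rolling-update strategy.")
    tiny = SimpleNamespace(document_id="c", text="too short")
    return (
        should_analyze_for_conflicts(None, long_a, long_b),  # True: distinct, long enough
        should_analyze_for_conflicts(None, long_a, tiny),    # False: second text < 10 chars
    )
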
async def get_tiered_analysis_pairs(detector: Any, documents: list[Any]) -> list[tuple]:
    """Generate tiered analysis pairs for conflict detection (extracted)."""
    pairs: list[tuple] = []
    if len(documents) < 2:
        return pairs

    for i, doc1 in enumerate(documents):
        for _j, doc2 in enumerate(documents[i + 1 :], i + 1):
            score = 1.0
            if hasattr(doc1, "score") and hasattr(doc2, "score"):
                avg_doc_score = (doc1.score + doc2.score) / 2
                score = min(1.0, avg_doc_score)

            if score >= 0.8:
                tier = "primary"
            elif score >= 0.5:
                tier = "secondary"
            else:
                tier = "tertiary"

            pairs.append((doc1, doc2, tier, score))

    pairs.sort(key=lambda x: x[3], reverse=True)
    return pairs
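
# Tiering example (hypothetical scored docs): each pair's tier comes from the average of
# the two document scores, capped at 1.0: >= 0.8 is "primary", >= 0.5 "secondary",
# otherwise "tertiary"; pairs come back sorted by that score, highest first.
def _example_tiered_pairs() -> list[tuple]:
    from types import SimpleNamespace

    docs = [
        SimpleNamespace(score=0.95),
        SimpleNamespace(score=0.85),
        SimpleNamespace(score=0.1),
    ]
    # Expected tiers in order: ("primary", 0.9), ("secondary", 0.525), ("tertiary", 0.475)
    return asyncio.run(get_tiered_analysis_pairs(None, docs))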