Coverage for src / qdrant_loader_mcp_server / search / hybrid / components / cross_encoder_reranker.py: 88%
85 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-10 09:41 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-10 09:41 +0000
1"""
2Cross-encoder reranker to implement the WRRF algorithm.
4Uses the CrossEncoder class from the sentence-transformers library to rerank search results.
5"""
7from __future__ import annotations
9import logging
10import threading
11from typing import TYPE_CHECKING, Any
13if TYPE_CHECKING:
14 from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """Re-rank search results using a cross-encoder model.

    Scores (query, text) pairs with a sentence-transformers ``CrossEncoder``
    and reorders results by descending relevance. Model loading is lazy and
    thread-safe; any failure (missing dependency, load error, scoring error)
    disables reranking and returns the original results untouched.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        device: str | None = None,
        batch_size: int = 32,
        enabled: bool = True,
    ):
        """Initialize the reranker.

        Args:
            model_name: Hugging Face identifier of the cross-encoder model.
            device: Torch device string ("cpu", "cuda", "mps"); auto-detected
                when None.
            batch_size: Number of (query, text) pairs scored per batch.
            enabled: When False, ``rerank`` is a no-op pass-through.
        """
        self.model_name = model_name
        self.device = device if device is not None else self._get_optimal_device()
        self.batch_size = batch_size
        self.enabled = enabled

        # Loaded lazily on first rerank() call; guarded by _model_lock.
        self.model: CrossEncoder | None = None
        self._model_lock = threading.Lock()
        self.logger = logging.getLogger(__name__)

        if not self.enabled:
            self.logger.info("Cross-encoder reranking disabled")

    def _get_optimal_device(self) -> str:
        """Pick the best available torch device, falling back to CPU."""
        try:
            import torch
        except ImportError:
            # torch is optional; CPU is always a valid CrossEncoder device.
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # Apple Silicon GPU (only present on torch builds with MPS support).
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

        return "cpu"

    def _load_model(self) -> None:
        """Load the cross-encoder model once, with double-checked locking.

        On any failure (sentence-transformers missing, model load error)
        reranking is disabled rather than raising, so callers degrade to
        the original result order.
        """
        if self.model is not None:
            return

        with self._model_lock:
            # Re-check under the lock: another thread may have loaded it.
            if self.model is not None:
                return

            try:
                from sentence_transformers import CrossEncoder
            except ImportError:
                self.logger.warning(
                    "sentence-transformers not installed. "
                    "Install with: pip install sentence-transformers"
                )
                self.enabled = False
                return

            try:
                # Lazy %-style args: formatting is skipped when the record
                # is filtered out by the log level.
                self.logger.info(
                    "Loading cross-encoder model: %s on %s",
                    self.model_name,
                    self.device,
                )
                self.model = CrossEncoder(self.model_name, device=self.device)
                self.logger.info("Cross-encoder model loaded")
            except Exception as e:
                # logger.exception preserves the traceback (logger.error
                # with an f-string did not).
                self.logger.exception("Failed to load cross-encoder model: %s", e)
                self.enabled = False
                self.model = None

    def rerank(
        self,
        query: str,
        results: list[Any],
        top_k: int | None = None,
        text_key: str = "text",
    ) -> list[Any]:
        """Re-rank *results* by cross-encoder relevance to *query*.

        Args:
            query: The search query.
            results: Search results — dicts or objects exposing *text_key*.
            top_k: If given, truncate the reranked list to this many items.
            text_key: Dict key / attribute name holding the result text.

        Returns:
            Results sorted by descending cross-encoder score, each annotated
            with ``cross_encoder_score`` and ``cross_encoder_rank``. The input
            is returned unchanged when reranking is disabled, *results* is
            empty, the model is unavailable, or scoring fails.
        """
        if not self.enabled or not results:
            return results

        self._load_model()
        if self.model is None:
            return results

        try:
            texts_with_indices = self._extract_texts_with_indices(results, text_key)
            if not texts_with_indices:
                return results

            pairs = [(query, text) for (_idx, text) in texts_with_indices]

            scores = self.model.predict(
                pairs,
                batch_size=self.batch_size,
                show_progress_bar=False,
            )

            # Map scores back to original results using saved indices.
            # strict=True: predict() must return exactly one score per pair;
            # strict=False would silently truncate and mis-attribute scores
            # on a mismatch. A ValueError here falls into the except below
            # and returns the original ordering.
            mapped = [
                (idx, results[idx], float(score))
                for (idx, _), score in zip(texts_with_indices, scores, strict=True)
            ]

            ranked = sorted(
                mapped,
                key=lambda entry: entry[2],
                reverse=True,
            )

            output = []
            for rank, (_idx, item, score) in enumerate(ranked, start=1):
                if isinstance(item, dict):
                    item["cross_encoder_score"] = score
                    item["cross_encoder_rank"] = rank
                else:
                    item.cross_encoder_score = score
                    item.cross_encoder_rank = rank
                    # NOTE(review): only object results have their primary
                    # score overwritten; dict results keep their original
                    # "score" key — confirm this asymmetry is intentional.
                    item.score = score
                output.append(item)

            return output if top_k is None else output[:top_k]

        except Exception as e:
            self.logger.exception("Cross-encoder reranking failed: %s", e)
            return results

    def _extract_texts_with_indices(
        self, results: list[Any], text_key: str
    ) -> list[tuple[int, str]]:
        """Extract (original_index, text) pairs suitable for scoring.

        Falls back to ``str(result)`` when *text_key* is absent. Blank texts
        are skipped; non-blank texts are truncated to 1000 characters to
        bound the cross-encoder input length.
        """
        texts: list[tuple[int, str]] = []

        for idx, result in enumerate(results):
            if isinstance(result, dict):
                text = result.get(text_key, "")
            elif hasattr(result, text_key):
                text = getattr(result, text_key)
            else:
                text = str(result)

            text = str(text).strip()
            if text:
                texts.append((idx, text[:1000]))

        return texts