Coverage for src / qdrant_loader_mcp_server / search / hybrid / components / cross_encoder_reranker.py: 88%
85 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 04:51 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-18 04:51 +0000
"""
Cross-encoder reranker implementing the WRRF algorithm.

Uses the CrossEncoder class from the sentence-transformers library to rerank search results.
"""
7from __future__ import annotations
9import logging
10import threading
11from typing import TYPE_CHECKING, Any
13if TYPE_CHECKING:
14 from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """Re-rank search results using a cross-encoder model.

    The model is loaded lazily and thread-safely on first use. If the
    optional ``sentence-transformers`` dependency is missing, or the model
    fails to load, the reranker disables itself and passes results through
    unchanged instead of raising — reranking is strictly best-effort.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        device: str | None = None,
        batch_size: int = 32,
        enabled: bool = True,
        max_text_length: int = 1000,
    ):
        """Initialize the reranker.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            device: Torch device string ("cuda", "mps", "cpu"). Auto-detected
                when None.
            batch_size: Number of (query, text) pairs scored per batch.
            enabled: When False, ``rerank`` is a no-op pass-through.
            max_text_length: Maximum number of characters of each result's
                text fed to the model; longer texts are truncated. Defaults
                to 1000, matching the previous hard-coded limit.
        """
        self.model_name = model_name
        self.device = device if device is not None else self._get_optimal_device()
        self.batch_size = batch_size
        self.enabled = enabled
        self.max_text_length = max_text_length

        # Loaded lazily by _load_model(); guarded by _model_lock.
        self.model: CrossEncoder | None = None
        self._model_lock = threading.Lock()
        self.logger = logging.getLogger(__name__)

        if not self.enabled:
            self.logger.info("Cross-encoder reranking disabled")

    def _get_optimal_device(self) -> str:
        """Pick the best available torch device, falling back to "cpu"."""
        try:
            import torch
        except ImportError:
            # No torch at all: CrossEncoder will not load either, but the
            # reranker degrades gracefully in _load_model().
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # Apple Silicon GPU; hasattr-guarded because older torch builds
        # do not expose the mps backend.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

        return "cpu"

    def _load_model(self) -> None:
        """Lazily load the cross-encoder model (double-checked locking).

        On import or load failure the reranker disables itself rather than
        raising, so ``rerank`` falls back to a pass-through.
        """
        if self.model is not None:
            return

        with self._model_lock:
            # Re-check under the lock: another thread may have loaded it.
            if self.model is not None:
                return

            try:
                from sentence_transformers import CrossEncoder
            except ImportError:
                self.logger.warning(
                    "sentence-transformers not installed. "
                    "Install with: pip install sentence-transformers"
                )
                self.enabled = False
                return

            try:
                self.logger.info(
                    "Loading cross-encoder model: %s on %s",
                    self.model_name,
                    self.device,
                )
                self.model = CrossEncoder(self.model_name, device=self.device)
                self.logger.info("Cross-encoder model loaded")
            except Exception as e:
                self.logger.error("Failed to load cross-encoder model: %s", e)
                self.enabled = False
                self.model = None

    def rerank(
        self,
        query: str,
        results: list[Any],
        top_k: int | None = None,
        text_key: str = "text",
    ) -> list[Any]:
        """Re-score ``results`` against ``query`` and sort by relevance.

        Args:
            query: The search query.
            results: Search results; each item is a dict or an object
                exposing its text under ``text_key``.
            top_k: If given, return at most this many results.
            text_key: Dict key / attribute name holding the result text.

        Returns:
            Results sorted by descending cross-encoder score, annotated
            with ``cross_encoder_score`` and ``cross_encoder_rank``. When
            disabled, on load failure, or on any scoring error, the input
            list is returned unchanged.
        """
        if not self.enabled or not results:
            return results

        self._load_model()
        if self.model is None:
            return results

        try:
            texts_with_indices = self._extract_texts_with_indices(results, text_key)
            if not texts_with_indices:
                return results

            pairs = [(query, text) for (_idx, text) in texts_with_indices]

            scores = self.model.predict(
                pairs,
                batch_size=self.batch_size,
                show_progress_bar=False,
            )

            # Map scores back to original results using saved indices.
            # NOTE(review): results skipped by _extract_texts_with_indices
            # (empty text) are absent from the output entirely — confirm
            # dropping them is intentional.
            mapped = [
                (idx, results[idx], float(score))
                for (idx, _), score in zip(texts_with_indices, scores, strict=False)
            ]

            ranked = sorted(
                mapped,
                key=lambda entry: entry[2],
                reverse=True,
            )

            output = []
            for rank, (_idx, item, score) in enumerate(ranked, start=1):
                if isinstance(item, dict):
                    item["cross_encoder_score"] = float(score)
                    item["cross_encoder_rank"] = rank
                else:
                    item.cross_encoder_score = float(score)
                    item.cross_encoder_rank = rank
                    # NOTE(review): only object results get their primary
                    # ``score`` overwritten; dict results keep any original
                    # "score" key — confirm this asymmetry is intentional.
                    item.score = float(score)
                output.append(item)

            return output if top_k is None else output[:top_k]

        except Exception as e:
            # Best-effort: reranking must never break search, so fall back
            # to the original ordering on any error.
            self.logger.error("Cross-encoder reranking failed: %s", e)
            return results

    def _extract_texts_with_indices(
        self, results: list[Any], text_key: str
    ) -> list[tuple[int, str]]:
        """Collect (original_index, truncated_text) pairs for scorable results.

        Results whose text is empty or whitespace-only are skipped. Texts
        are truncated to ``self.max_text_length`` characters to bound the
        cross-encoder's input size.
        """
        texts: list[tuple[int, str]] = []

        for idx, r in enumerate(results):
            if isinstance(r, dict):
                text = r.get(text_key, "")
            elif hasattr(r, text_key):
                text = getattr(r, text_key)
            else:
                # No text field at all: fall back to the string form.
                text = str(r)

            text = str(text).strip()
            if text:
                texts.append((idx, text[: self.max_text_length]))

        return texts