Coverage for src/qdrant_loader_mcp_server/search/hybrid/components/cross_encoder_reranker.py: 88%

85 statements

« prev    ^ index    » next    coverage.py v7.13.5, created at 2026-04-10 09:41 +0000

"""
Cross-encoder reranker to implement the WRRF algorithm.

Uses the CrossEncoder class from the sentence-transformers library to rerank search results.
"""

6 

7from __future__ import annotations 

8 

9import logging 

10import threading 

11from typing import TYPE_CHECKING, Any 

12 

13if TYPE_CHECKING: 

14 from sentence_transformers import CrossEncoder 

15 

16 

class CrossEncoderReranker:
    """Re-rank search results using a cross-encoder model.

    The underlying sentence-transformers ``CrossEncoder`` is loaded lazily
    and thread-safely on first use, so constructing an instance is cheap.
    If the library is missing or the model fails to load, the instance
    disables itself and ``rerank`` becomes a pass-through.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        device: str | None = None,
        batch_size: int = 32,
        enabled: bool = True,
    ):
        """Initialize the reranker.

        Args:
            model_name: Model id passed to ``CrossEncoder``.
            device: Inference device ("cuda", "mps", "cpu"); auto-detected
                when None.
            batch_size: Batch size forwarded to ``CrossEncoder.predict``.
            enabled: When False, ``rerank`` returns its input unchanged.
        """
        self.model_name = model_name
        self.device = device if device is not None else self._get_optimal_device()
        self.batch_size = batch_size
        self.enabled = enabled

        # Model is loaded lazily by _load_model(); the lock prevents two
        # threads from loading it concurrently.
        self.model: CrossEncoder | None = None
        self._model_lock = threading.Lock()
        self.logger = logging.getLogger(__name__)

        if not self.enabled:
            self.logger.info("Cross-encoder reranking disabled")

    def _get_optimal_device(self) -> str:
        """Pick the best available torch device, falling back to CPU."""
        try:
            import torch
        except ImportError:
            # torch not installed: CPU-only operation.
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # Apple Silicon GPU; hasattr guards older torch builds without
        # the mps backend attribute.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

        return "cpu"

    def _load_model(self) -> None:
        """Load the cross-encoder model once, with double-checked locking.

        On ImportError or any load failure the reranker disables itself
        (``enabled = False``, ``model = None``) instead of raising, so
        callers degrade gracefully to the original ranking.
        """
        if self.model is not None:
            return

        with self._model_lock:
            # Re-check under the lock: another thread may have finished
            # loading while we waited.
            if self.model is not None:
                return

            try:
                from sentence_transformers import CrossEncoder
            except ImportError:
                self.logger.warning(
                    "sentence-transformers not installed. "
                    "Install with: pip install sentence-transformers"
                )
                self.enabled = False
                return

            try:
                # Lazy %-style args: no string formatting when the record
                # is filtered out by the log level.
                self.logger.info(
                    "Loading cross-encoder model: %s on %s",
                    self.model_name,
                    self.device,
                )
                self.model = CrossEncoder(self.model_name, device=self.device)
                self.logger.info("Cross-encoder model loaded")
            except Exception as e:
                self.logger.error("Failed to load cross-encoder model: %s", e)
                self.enabled = False
                self.model = None

    def rerank(
        self,
        query: str,
        results: list[Any],
        top_k: int | None = None,
        text_key: str = "text",
    ) -> list[Any]:
        """Re-order *results* by cross-encoder relevance to *query*.

        Args:
            query: The search query.
            results: Search results — dicts or objects exposing *text_key*.
            top_k: If given, truncate the reranked list to this length.
            text_key: Dict key / attribute name holding each result's text.

        Returns:
            Results sorted by descending cross-encoder score, annotated with
            ``cross_encoder_score`` and ``cross_encoder_rank``. Results whose
            extracted text is empty are dropped from the output (they never
            reach the model). On any failure, or when disabled, the input
            list is returned unchanged.
        """
        if not self.enabled or not results:
            return results

        self._load_model()
        if self.model is None:
            # Loading failed (or library missing): degrade gracefully.
            return results

        try:
            texts_with_indices = self._extract_texts_with_indices(results, text_key)
            if not texts_with_indices:
                return results

            pairs = [(query, text) for (_idx, text) in texts_with_indices]

            scores = self.model.predict(
                pairs,
                batch_size=self.batch_size,
                show_progress_bar=False,
            )

            # Map scores back to original results using saved indices
            mapped = [
                (idx, results[idx], float(score))
                for (idx, _), score in zip(texts_with_indices, scores, strict=False)
            ]

            ranked = sorted(
                mapped,
                key=lambda x: x[2],
                reverse=True,
            )

            output = []
            for rank, (_idx, item, score) in enumerate(ranked, start=1):
                if isinstance(item, dict):
                    item["cross_encoder_score"] = float(score)
                    item["cross_encoder_rank"] = rank
                else:
                    item.cross_encoder_score = float(score)
                    item.cross_encoder_rank = rank
                    # NOTE(review): only non-dict results get their primary
                    # score overwritten — confirm this asymmetry is intended.
                    item.score = float(score)
                output.append(item)

            return output if top_k is None else output[:top_k]

        except Exception as e:
            self.logger.error("Cross-encoder reranking failed: %s", e)
            return results

    def _extract_texts_with_indices(
        self, results: list[Any], text_key: str
    ) -> list[tuple[int, str]]:
        """Collect (original_index, text) pairs for results with non-empty text.

        Text is stripped and truncated to 1000 characters, bounding the
        input fed to the cross-encoder; results without usable text are
        skipped entirely.
        """
        texts: list[tuple[int, str]] = []

        for idx, r in enumerate(results):
            if isinstance(r, dict):
                text = r.get(text_key, "")
            elif hasattr(r, text_key):
                text = getattr(r, text_key)
            else:
                # Last resort: use the object's string form as its text.
                text = str(r)

            text = str(text).strip()
            if text:
                texts.append((idx, text[:1000]))

        return texts