Coverage for src / qdrant_loader_mcp_server / search / hybrid / components / cross_encoder_reranker.py: 88%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 04:51 +0000

1""" 

2Cross-encoder reranker to implement the WRRF algorithm. 

3 

4Uses the CrossEncoder class from the sentence-transformers library to rerank search results. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import threading 

11from typing import TYPE_CHECKING, Any 

12 

13if TYPE_CHECKING: 

14 from sentence_transformers import CrossEncoder 

15 

class CrossEncoderReranker:
    """Re-rank search results using a cross-encoder model.

    The ``sentence_transformers.CrossEncoder`` model is loaded lazily (and
    thread-safely) on first use.  Failure modes degrade gracefully: if the
    library is not installed or the model fails to load, the reranker
    disables itself and ``rerank()`` returns results unchanged.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        device: str | None = None,
        batch_size: int = 32,
        enabled: bool = True,
    ):
        """Initialize the reranker.

        Args:
            model_name: Hugging Face model id of the cross-encoder.
            device: Torch device string ("cpu", "cuda", "mps").  When None,
                the best available device is auto-detected.
            batch_size: Batch size passed to ``CrossEncoder.predict``.
            enabled: When False, ``rerank()`` is a no-op pass-through.
        """
        self.model_name = model_name
        # Auto-detect the best available device unless explicitly overridden.
        self.device = device if device is not None else self._get_optimal_device()
        self.batch_size = batch_size
        self.enabled = enabled

        # Model is loaded lazily on the first rerank() call; the lock makes
        # the double-checked load in _load_model() thread-safe.
        self.model: CrossEncoder | None = None
        self._model_lock = threading.Lock()
        self.logger = logging.getLogger(__name__)

        if not self.enabled:
            self.logger.info("Cross-encoder reranking disabled")

    def _get_optimal_device(self) -> str:
        """Return the best available torch device ("cuda", "mps", or "cpu").

        Falls back to "cpu" when torch itself is not installed.
        """
        try:
            import torch
        except ImportError:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # Apple Silicon GPU support; guard hasattr for older torch builds.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

        return "cpu"

    def _load_model(self) -> None:
        """Load the cross-encoder model if not already loaded (thread-safe).

        On ImportError or any model-load failure, logs the problem and sets
        ``self.enabled = False`` so subsequent calls become no-ops.
        """
        # Fast path: already loaded, skip the lock entirely.
        if self.model is not None:
            return

        with self._model_lock:
            # Double-checked locking: another thread may have loaded it
            # while we were waiting for the lock.
            if self.model is not None:
                return

            try:
                from sentence_transformers import CrossEncoder
            except ImportError:
                self.logger.warning(
                    "sentence-transformers not installed. "
                    "Install with: pip install sentence-transformers"
                )
                self.enabled = False
                return

            try:
                # Lazy %-style args keep formatting off the hot path.
                self.logger.info(
                    "Loading cross-encoder model: %s on %s",
                    self.model_name,
                    self.device,
                )
                self.model = CrossEncoder(self.model_name, device=self.device)
                self.logger.info("Cross-encoder model loaded")
            except Exception as e:
                self.logger.error("Failed to load cross-encoder model: %s", e)
                self.enabled = False
                self.model = None

    def rerank(
        self,
        query: str,
        results: list[Any],
        top_k: int | None = None,
        text_key: str = "text",
    ) -> list[Any]:
        """Re-rank ``results`` by cross-encoder relevance to ``query``.

        Args:
            query: The search query to score each result against.
            results: Result items; dicts are read via ``text_key``, other
                objects via the ``text_key`` attribute (or ``str()`` fallback).
            top_k: When given, truncate the reranked list to this length.
            text_key: Dict key / attribute name holding the result text.

        Returns:
            Results ordered by descending cross-encoder score, annotated with
            ``cross_encoder_score`` and ``cross_encoder_rank``.  Results with
            no extractable text are kept (appended after the scored ones, in
            their original order) rather than dropped.  On any failure, or
            when disabled, the input list is returned unchanged.
        """
        if not self.enabled or not results:
            return results

        self._load_model()
        if self.model is None:
            return results

        try:
            texts_with_indices = self._extract_texts_with_indices(results, text_key)
            if not texts_with_indices:
                return results

            pairs = [(query, text) for (_idx, text) in texts_with_indices]

            scores = self.model.predict(
                pairs,
                batch_size=self.batch_size,
                show_progress_bar=False,
            )

            # Map scores back to original results using saved indices.
            mapped = [
                (idx, results[idx], float(score))
                for (idx, _), score in zip(texts_with_indices, scores, strict=False)
            ]

            ranked = sorted(
                mapped,
                key=lambda x: x[2],
                reverse=True,
            )

            output = []
            for rank, (_idx, item, score) in enumerate(ranked, start=1):
                if isinstance(item, dict):
                    item["cross_encoder_score"] = float(score)
                    item["cross_encoder_rank"] = rank
                else:
                    item.cross_encoder_score = float(score)
                    item.cross_encoder_rank = rank
                    # NOTE: for object results the primary score is replaced
                    # by the cross-encoder score (dicts keep their original).
                    item.score = float(score)
                output.append(item)

            # Fix: previously results with empty text were silently dropped.
            # Keep them, unscored, after the reranked items.
            scored_indices = {idx for idx, _ in texts_with_indices}
            for idx, item in enumerate(results):
                if idx not in scored_indices:
                    output.append(item)

            return output if top_k is None else output[:top_k]

        except Exception as e:
            # Best-effort: never let reranking break the search pipeline.
            self.logger.error("Cross-encoder reranking failed: %s", e)
            return results

    def _extract_texts_with_indices(
        self, results: list[Any], text_key: str
    ) -> list[tuple[int, str]]:
        """Extract ``(original_index, text)`` pairs for scorable results.

        Dict results are read via ``text_key``; other objects via the
        attribute of that name, falling back to ``str(result)``.  Entries
        whose text is empty after stripping are omitted.  Texts are capped
        at 1000 characters to bound cross-encoder input length.
        """
        texts: list[tuple[int, str]] = []

        for idx, r in enumerate(results):
            if isinstance(r, dict):
                text = r.get(text_key, "")
            elif hasattr(r, text_key):
                text = getattr(r, text_key)
            else:
                text = str(r)

            text = str(text).strip()
            if text:
                texts.append((idx, text[:1000]))

        return texts

157 return texts