Coverage for src / qdrant_loader_core / sparse / bm25.py: 59%
59 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-11 09:34 +0000
1from __future__ import annotations
3import hashlib
4import math
5import re
6from collections import Counter
7from dataclasses import dataclass
8from functools import cache
# A token is a run of two or more ASCII letters, digits, or underscores;
# single-character runs are never matched.
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]{2,}")

# Default modulus applied to token hashes (2**31 - 1).
_DEFAULT_HASH_MOD = 2**31 - 1

# Stop words removed during tokenization when the caller supplies no list.
_DEFAULT_STOP_WORDS: frozenset[str] = frozenset(
    (
        "a", "an", "and", "are", "as",
        "at", "be", "by", "for", "from",
        "has", "he", "in", "is", "it",
        "its", "of", "on", "that", "the",
        "to", "was", "were", "will", "with",
    )
)
@dataclass(frozen=True)
class SparseVectorData:
    """Immutable pair of parallel lists describing a sparse vector.

    ``indices[i]`` is the dimension whose weight is ``values[i]``.
    """

    indices: list[int]
    values: list[float]

    def is_empty(self) -> bool:
        """Return ``True`` when the vector has no indices."""
        if self.indices:
            return False
        return True
class BM25SparseEncoder:
    """Deterministic BM25-family sparse encoder using hashed token IDs.

    Text is lower-cased, split into ASCII word tokens, filtered against a
    stop-word list, and every surviving token is mapped to an integer index
    by hashing. No corpus statistics are kept, so encoding the same text
    with the same settings always yields the same vector.

    Args:
        model: Variant name. ``"bm25_lite"`` / ``"bm25-lite"`` select
            ``k1 = 0.9``; any other name uses ``k1 = 1.2``. ``None``, empty
            and whitespace-only names fall back to ``"bm25"``.
        hash_mod: Modulus applied to token hashes. Values below 10,000 are
            raised to 10,000.
        stop_words: Tokens to drop during tokenization. ``None`` selects the
            module's default list. Entries are lower-cased and stripped.
    """

    def __init__(
        self,
        model: str = "bm25",
        *,
        hash_mod: int = _DEFAULT_HASH_MOD,
        stop_words: set[str] | frozenset[str] | None = None,
    ):
        # Strip *before* applying the default so a whitespace-only name ends
        # up as "bm25" instead of the empty string.
        self.model = (model or "").strip().lower() or "bm25"
        self.hash_mod = max(10_000, int(hash_mod))
        # Snapshot into a frozenset so callers can't mutate the cached
        # encoder's behavior after construction, and normalize casing for
        # safety.
        source = _DEFAULT_STOP_WORDS if stop_words is None else stop_words
        self.stop_words: frozenset[str] = frozenset(
            word.lower().strip() for word in source
        )

        # The "lite" variant saturates term frequency sooner (smaller k1).
        self.k1 = 0.9 if self.model in {"bm25_lite", "bm25-lite"} else 1.2

    def _tokenize(self, text: str) -> list[str]:
        """Return lower-cased tokens of ``text`` with stop words removed.

        Empty or ``None`` text yields an empty list.
        """
        if not text:
            return []
        stop_words = self.stop_words
        return [
            token
            for token in _TOKEN_RE.findall(text.lower())
            if token not in stop_words
        ]

    def _token_to_index(self, token: str) -> int:
        """Map ``token`` to a stable index in ``1..hash_mod`` via BLAKE2b."""
        digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest()
        bucket = int.from_bytes(digest, byteorder="big", signed=False) % self.hash_mod
        # Shift by one so index 0 is never produced.
        return bucket + 1

    def _encode_with_weights(self, text: str, *, query_mode: bool) -> SparseVectorData:
        """Encode ``text`` into a sparse vector sorted by index.

        Args:
            text: Raw text to encode.
            query_mode: When true, weight each term as ``1 + log(tf)``;
                otherwise apply BM25 term-frequency saturation.

        Returns:
            A :class:`SparseVectorData` with ascending indices; empty when no
            token survives tokenization.
        """
        # Document mode applies BM25 TF saturation without the IDF term:
        # standard BM25 is IDF * (tf * (k1 + 1)) / (tf + k1); this encoder uses
        # only (tf * (k1 + 1)) / (tf + k1) because no corpus statistics are
        # available at encode time (hashed token IDs, single-document context).
        # Qdrant computes the sparse dot product server-side, so omitting IDF
        # keeps encoding deterministic and stateless at the cost of not
        # down-weighting common terms — acceptable since query-side weighting
        # (1 + log tf) and dense retrieval compensate in the hybrid pipeline.
        tokens = self._tokenize(text)
        if not tokens:
            return SparseVectorData(indices=[], values=[])

        k1 = self.k1
        weighted: dict[int, float] = {}
        for token, count in Counter(tokens).items():
            tf = float(count)
            if query_mode:
                weight = 1.0 + math.log(tf)
            else:
                weight = (tf * (k1 + 1.0)) / (tf + k1)
            index = self._token_to_index(token)
            # Distinct tokens that hash to the same index have their weights
            # summed.
            weighted[index] = weighted.get(index, 0.0) + weight

        ordered = sorted(weighted)
        return SparseVectorData(
            indices=ordered,
            values=[float(weighted[index]) for index in ordered],
        )

    def encode_document(self, text: str) -> SparseVectorData:
        """Encode ``text`` with document-side (TF-saturation) weights."""
        return self._encode_with_weights(text, query_mode=False)

    def encode_query(self, text: str) -> SparseVectorData:
        """Encode ``text`` with query-side (``1 + log tf``) weights."""
        return self._encode_with_weights(text, query_mode=True)
@cache
def _get_sparse_encoder_cached(normalized_model: str) -> BM25SparseEncoder:
    """Build the encoder for ``normalized_model`` once and memoize it.

    The cache is keyed on the exact string passed in, so callers are expected
    to normalize the model name beforehand.
    """
    return BM25SparseEncoder(normalized_model)
def get_sparse_encoder(model: str) -> BM25SparseEncoder:
    """Return a process-wide :class:`BM25SparseEncoder` for ``model``.

    Encoders are deterministic and hold no mutable state, so a single instance
    per model name is safe to share across the loader and the MCP server.
    The model name is normalized (stripped, lower-cased) before caching so
    equivalent inputs like ``"BM25"`` and ``" bm25 "`` collapse to one entry.
    ``None``, empty, and whitespace-only names all resolve to ``"bm25"``.
    """
    # Strip before applying the default: otherwise a whitespace-only name
    # normalizes to "" and gets its own cache entry next to "bm25".
    normalized = (model or "").strip().lower() or "bm25"
    return _get_sparse_encoder_cached(normalized)