Coverage for src / qdrant_loader_mcp_server / search / nlp / spacy_analyzer.py: 93%

183 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-18 04:51 +0000

1"""spaCy-powered query analysis for intelligent search.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6from typing import Any 

7 

8# Note: spacy is imported at module level because: 

9# 1. This module (spacy_analyzer) is only imported when SpaCyQueryAnalyzer is needed 

10# 2. Tests need to be able to patch spacy.load 

11# The lazy loading happens at a higher level - this module is not imported at MCP startup 

12import spacy 

13from spacy.cli.download import download as spacy_download 

14from spacy.tokens import Doc 

15 

16from ...utils.logging import LoggingConfig 

17 

18logger = LoggingConfig.get_logger(__name__) 

19 

20 

@dataclass
class QueryAnalysis:
    """Container for spaCy-based query analysis results.

    Produced by :meth:`SpaCyQueryAnalyzer.analyze_query_semantic`. Instances
    are cached per query string by the analyzer, so treat them as read-only.
    """

    # Core linguistic analysis
    entities: list[tuple[str, str]]  # (text, label) pairs from doc.ents
    pos_patterns: list[str]  # Part-of-speech tags, one per non-space token
    semantic_keywords: list[str]  # Lemmatized, lowercased, stopword-filtered keywords
    intent_signals: dict[str, Any]  # Intent detection based on linguistic patterns
    main_concepts: list[str]  # Noun chunks representing main concepts

    # Semantic understanding
    query_vector: Any  # The spaCy Doc itself, kept for similarity matching
    semantic_similarity_cache: dict[str, float]  # Cache for similarity scores

    # Query characteristics
    is_question: bool
    is_technical: bool
    complexity_score: float  # 0-1 score based on linguistic complexity

    # Processing metadata
    processed_tokens: int  # len(doc), includes punctuation/space tokens
    processing_time_ms: float

44 

45 

class SpaCyQueryAnalyzer:
    """Enhanced query analysis using spaCy NLP with en_core_web_md model.

    Provides linguistic query analysis (entities, POS patterns, intent
    detection, complexity scoring) plus vector-based semantic similarity.
    Results are memoized in unbounded per-instance caches; long-lived
    processes should call :meth:`clear_cache` periodically.
    """

    def __init__(self, spacy_model: str = "en_core_web_md"):
        """Initialize the spaCy query analyzer.

        Args:
            spacy_model: spaCy model to use (default: en_core_web_md with 20k word vectors)
        """
        self.spacy_model = spacy_model
        self.nlp = self._load_spacy_model()
        self.logger = LoggingConfig.get_logger(__name__)

        # Intent pattern definitions using POS tags and linguistic features.
        # Each pattern contributes a weighted score in _detect_intent_patterns:
        # entity types 0.3, POS sequences 0.3, keywords 0.2, question words 0.2.
        self.intent_patterns = {
            "technical_lookup": {
                "entities": {"ORG", "PRODUCT", "PERSON", "GPE"},
                "pos_sequences": [["NOUN", "NOUN"], ["ADJ", "NOUN"], ["VERB", "NOUN"]],
                "keywords": {
                    "api",
                    "database",
                    "architecture",
                    "implementation",
                    "system",
                    "code",
                    "function",
                },
                "question_words": set(),
            },
            "business_context": {
                "entities": {"ORG", "MONEY", "PERCENT", "CARDINAL"},
                "pos_sequences": [["NOUN", "NOUN"], ["ADJ", "NOUN", "NOUN"]],
                "keywords": {
                    "requirements",
                    "objectives",
                    "strategy",
                    "business",
                    "scope",
                    "goals",
                },
                "question_words": {"what", "why", "how"},
            },
            "vendor_evaluation": {
                "entities": {"ORG", "MONEY", "PERSON"},
                "pos_sequences": [["NOUN", "NOUN"], ["VERB", "NOUN"], ["ADJ", "NOUN"]],
                "keywords": {
                    "proposal",
                    "criteria",
                    "cost",
                    "vendor",
                    "comparison",
                    "evaluation",
                },
                "question_words": {"which", "what", "how much"},
            },
            "procedural": {
                "entities": set(),
                "pos_sequences": [["VERB", "NOUN"], ["VERB", "DET", "NOUN"]],
                "keywords": {
                    "how",
                    "steps",
                    "process",
                    "procedure",
                    "guide",
                    "tutorial",
                },
                "question_words": {"how", "when", "where"},
            },
            "informational": {
                "entities": set(),
                "pos_sequences": [["NOUN"], ["ADJ", "NOUN"]],
                "keywords": {"what", "definition", "meaning", "overview", "about"},
                "question_words": {"what", "who", "when", "where"},
            },
        }

        # Caches for processed queries to improve performance.
        # NOTE: unbounded; bounded growth is the caller's responsibility
        # via clear_cache().
        self._analysis_cache: dict[str, QueryAnalysis] = {}
        self._similarity_cache: dict[tuple[str, str], float] = {}

    def _load_spacy_model(self) -> spacy.Language:
        """Load spaCy model with error handling and auto-download.

        Tries the configured model, auto-downloads it if missing, then falls
        back to en_core_web_sm as a last resort.

        Raises:
            RuntimeError: if neither the requested model nor the fallback
                model can be downloaded and loaded (original failure chained).
        """
        try:
            nlp = spacy.load(self.spacy_model)
            # Verify model has vectors for semantic similarity
            if not nlp.meta.get("vectors", {}).get("vectors", 0):
                logger.warning(
                    f"spaCy model {self.spacy_model} loaded but has no word vectors. "
                    "Semantic similarity features will be limited."
                )
            else:
                logger.info(
                    f"spaCy model {self.spacy_model} loaded successfully with "
                    f"{nlp.meta['vectors']['vectors']} word vectors"
                )
            return nlp
        except OSError:
            # spaCy raises OSError when the model package is not installed.
            logger.info(f"spaCy model {self.spacy_model} not found. Downloading...")
            try:
                spacy_download(self.spacy_model)
                nlp = spacy.load(self.spacy_model)
                logger.info(f"Successfully downloaded and loaded {self.spacy_model}")
                return nlp
            except Exception as e:
                logger.error(f"Failed to download spaCy model {self.spacy_model}: {e}")
                # Fallback to a basic model
                try:
                    logger.warning("Falling back to en_core_web_sm model")
                    spacy_download("en_core_web_sm")
                    return spacy.load("en_core_web_sm")
                except Exception as fallback_error:
                    logger.error(f"Failed to load fallback model: {fallback_error}")
                    # Chain the cause (PEP 3134) so the underlying download
                    # failure is preserved in the traceback.
                    raise RuntimeError(
                        f"Could not load any spaCy model. Please install {self.spacy_model} manually."
                    ) from fallback_error

    def analyze_query_semantic(self, query: str) -> QueryAnalysis:
        """Enhanced query analysis using spaCy NLP.

        Args:
            query: The search query to analyze

        Returns:
            QueryAnalysis containing comprehensive linguistic analysis.
            Results are cached per query string, so repeated calls with the
            same query return the same (shared) instance.
        """
        # Imported lazily; this method is not on the module-import path.
        import time

        start_time = time.time()

        # Check cache first
        if query in self._analysis_cache:
            cached = self._analysis_cache[query]
            logger.debug(f"Using cached analysis for query: {query[:50]}...")
            return cached

        # Process query with spaCy
        doc = self.nlp(query)

        # Extract entities with confidence
        entities = [(ent.text, ent.label_) for ent in doc.ents]

        # Get POS patterns
        pos_patterns = [token.pos_ for token in doc if not token.is_space]

        # Extract semantic keywords (lemmatized, filtered)
        semantic_keywords = [
            token.lemma_.lower()
            for token in doc
            if (
                token.is_alpha
                and not token.is_stop
                and not token.is_punct
                and len(token.text) > 2
            )
        ]

        # Extract main concepts (noun chunks)
        main_concepts = [
            chunk.text.strip()
            for chunk in doc.noun_chunks
            if len(chunk.text.strip()) > 2
        ]

        # Detect intent using linguistic patterns
        intent_signals = self._detect_intent_patterns(
            doc, entities, pos_patterns, semantic_keywords
        )

        # Query characteristics
        is_question = self._is_question(doc)
        is_technical = self._is_technical_query(doc, entities, semantic_keywords)
        complexity_score = self._calculate_complexity_score(doc)

        # Processing metadata
        processing_time_ms = (time.time() - start_time) * 1000

        # Create analysis result
        analysis = QueryAnalysis(
            entities=entities,
            pos_patterns=pos_patterns,
            semantic_keywords=semantic_keywords,
            intent_signals=intent_signals,
            main_concepts=main_concepts,
            query_vector=doc,  # Store the spaCy Doc for similarity calculations
            semantic_similarity_cache={},
            is_question=is_question,
            is_technical=is_technical,
            complexity_score=complexity_score,
            processed_tokens=len(doc),
            processing_time_ms=processing_time_ms,
        )

        # Cache the result
        self._analysis_cache[query] = analysis

        logger.debug(
            f"Analyzed query in {processing_time_ms:.2f}ms",
            query_length=len(query),
            entities_found=len(entities),
            keywords_extracted=len(semantic_keywords),
            intent=intent_signals.get("primary_intent", "unknown"),
        )

        return analysis

    def semantic_similarity_matching(
        self, query_analysis: QueryAnalysis, entity_text: str
    ) -> float:
        """Calculate semantic similarity using spaCy word vectors.

        Args:
            query_analysis: Analyzed query containing the query vector
            entity_text: Text to compare similarity with

        Returns:
            Similarity score between 0.0 and 1.0 (0.0 on processing errors)
        """
        # Check cache first
        cache_key = (str(query_analysis.query_vector), entity_text)
        if cache_key in self._similarity_cache:
            return self._similarity_cache[cache_key]

        try:
            # Process entity text
            entity_doc = self.nlp(entity_text)

            # Calculate similarity using spaCy vectors
            if query_analysis.query_vector.has_vector and entity_doc.has_vector:
                raw_similarity = query_analysis.query_vector.similarity(entity_doc)
            else:
                # Fallback to token-based similarity if no vectors
                raw_similarity = self._token_similarity_fallback(
                    query_analysis.semantic_keywords, entity_text.lower()
                )

            # spaCy's Doc.similarity is a cosine that can be slightly
            # negative and returns a numpy float; clamp and cast so the
            # documented 0.0-1.0 float contract actually holds.
            similarity = float(min(max(raw_similarity, 0.0), 1.0))

            # Cache the result
            self._similarity_cache[cache_key] = similarity

            return similarity

        except Exception as e:
            logger.warning(f"Error calculating similarity for '{entity_text}': {e}")
            return 0.0

    def _detect_intent_patterns(
        self,
        doc: Doc,
        entities: list[tuple[str, str]],
        pos_patterns: list[str],
        semantic_keywords: list[str],
    ) -> dict[str, Any]:
        """Detect query intent using POS patterns and linguistic features.

        Scores every pattern in self.intent_patterns and picks the best one;
        falls back to "general" when the best score is below 0.3.
        """
        intent_scores = {}

        # Convert entities and keywords to sets for faster lookup
        entity_labels = {label for _, label in entities}
        keyword_set = set(semantic_keywords)

        # Score each intent pattern
        for intent_name, pattern in self.intent_patterns.items():
            score = 0.0

            # Entity type matching (fraction of the pattern's entity types seen)
            entity_match = len(entity_labels.intersection(pattern["entities"])) / max(
                len(pattern["entities"]), 1
            )
            score += entity_match * 0.3

            # POS sequence matching
            pos_match = self._match_pos_sequences(
                pos_patterns, pattern["pos_sequences"]
            )
            score += pos_match * 0.3

            # Keyword matching (fraction of the pattern's keywords seen)
            keyword_match = len(keyword_set.intersection(pattern["keywords"])) / max(
                len(pattern["keywords"]), 1
            )
            score += keyword_match * 0.2

            # Question word matching
            question_match = self._match_question_words(doc, pattern["question_words"])
            score += question_match * 0.2

            intent_scores[intent_name] = score

        # Find primary intent
        primary_intent = (
            max(intent_scores, key=intent_scores.get) if intent_scores else "general"
        )
        primary_score = intent_scores.get(primary_intent, 0.0)

        # Only use intent if confidence is above threshold
        if primary_score < 0.3:
            primary_intent = "general"

        return {
            "primary_intent": primary_intent,
            "confidence": primary_score,
            "all_scores": intent_scores,
            "linguistic_features": {
                "has_entities": len(entities) > 0,
                "has_question_words": any(
                    token.text.lower() in {"what", "how", "why", "when", "who", "where"}
                    for token in doc
                ),
                "verb_count": sum(1 for pos in pos_patterns if pos in {"VERB", "AUX"}),
                "noun_count": sum(1 for pos in pos_patterns if pos == "NOUN"),
            },
        }

    def _match_pos_sequences(
        self, pos_patterns: list[str], target_sequences: list[list[str]]
    ) -> float:
        """Return the fraction of target POS sequences found in the query."""
        if not target_sequences or not pos_patterns:
            return 0.0

        matches = 0
        total_sequences = len(target_sequences)

        for sequence in target_sequences:
            if self._contains_sequence(pos_patterns, sequence):
                matches += 1

        return matches / total_sequences

    def _contains_sequence(self, pos_patterns: list[str], sequence: list[str]) -> bool:
        """Check if POS patterns contain a specific contiguous sequence."""
        if len(sequence) > len(pos_patterns):
            return False

        for i in range(len(pos_patterns) - len(sequence) + 1):
            if pos_patterns[i : i + len(sequence)] == sequence:
                return True

        return False

    def _match_question_words(self, doc: Doc, question_words: set[str]) -> float:
        """Return the fraction of the pattern's question words present in doc."""
        if not question_words:
            return 0.0

        found_words = {
            token.text.lower() for token in doc if token.text.lower() in question_words
        }
        return len(found_words) / len(question_words)

    def _is_question(self, doc: Doc) -> bool:
        """Detect if the query is a question using linguistic features."""
        # Check for question marks
        if "?" in doc.text:
            return True

        # Check for question words at the beginning
        question_words = {
            "what",
            "how",
            "why",
            "when",
            "who",
            "where",
            "which",
            "whose",
            "whom",
        }
        first_token = doc[0] if doc else None
        if first_token and first_token.text.lower() in question_words:
            return True

        # Check for auxiliary verbs at the beginning (e.g., "Can you", "Do we", "Is there")
        if len(doc) >= 2:
            first_two = [token.text.lower() for token in doc[:2]]
            aux_patterns = {
                ("can", "you"),
                ("do", "we"),
                ("is", "there"),
                ("are", "there"),
                ("will", "you"),
            }
            if tuple(first_two) in aux_patterns:
                return True

        return False

    def _is_technical_query(
        self, doc: Doc, entities: list[tuple[str, str]], keywords: list[str]
    ) -> bool:
        """Detect if the query is technical in nature."""
        technical_indicators = {
            "api",
            "database",
            "system",
            "code",
            "function",
            "architecture",
            "implementation",
            "framework",
            "library",
            "server",
            "client",
            "protocol",
            "algorithm",
            "data",
            "query",
            "schema",
            "endpoint",
        }

        # Check keywords
        keyword_set = set(keywords)
        if keyword_set.intersection(technical_indicators):
            return True

        # Check for technical entity types
        technical_entities = {
            "ORG",
            "PRODUCT",
            "LANGUAGE",
        }  # Often technical in this context
        entity_labels = {label for _, label in entities}
        if entity_labels.intersection(technical_entities):
            return True

        return False

    def _calculate_complexity_score(self, doc: Doc) -> float:
        """Calculate query complexity (0-1) based on linguistic features."""
        if not doc:
            return 0.0

        # Factors that contribute to complexity, each normalized to [0, 1]
        factors = {
            "length": min(len(doc) / 20, 1.0),  # Longer queries are more complex
            "entities": min(len(doc.ents) / 5, 1.0),  # More entities = more complex
            "noun_chunks": min(
                len(list(doc.noun_chunks)) / 5, 1.0
            ),  # More concepts = more complex
            "question_words": min(
                sum(
                    1
                    for token in doc
                    if token.text.lower()
                    in {"what", "how", "why", "when", "who", "where", "which"}
                )
                / 3,
                1.0,
            ),
            "dependency_depth": min(self._max_dependency_depth(doc) / 5, 1.0),
        }

        # Weighted average (weights sum to 1.0)
        weights = {
            "length": 0.2,
            "entities": 0.3,
            "noun_chunks": 0.2,
            "question_words": 0.15,
            "dependency_depth": 0.15,
        }

        complexity = sum(factors[key] * weights[key] for key in factors)
        return min(complexity, 1.0)

    def _max_dependency_depth(self, doc: Doc) -> int:
        """Calculate maximum dependency tree depth."""
        max_depth = 0

        def get_depth(token, current_depth=0):
            nonlocal max_depth
            max_depth = max(max_depth, current_depth)
            for child in token.children:
                get_depth(child, current_depth + 1)

        for token in doc:
            if token.head == token:  # Root token (spaCy roots head themselves)
                get_depth(token)

        return max_depth

    def _token_similarity_fallback(
        self, query_keywords: list[str], entity_text: str
    ) -> float:
        """Fallback similarity calculation when word vectors are unavailable."""
        if not query_keywords:
            return 0.0

        entity_words = set(entity_text.lower().split())
        query_word_set = set(query_keywords)

        # Simple Jaccard similarity
        intersection = query_word_set.intersection(entity_words)
        union = query_word_set.union(entity_words)

        return len(intersection) / len(union) if union else 0.0

    def clear_cache(self):
        """Clear analysis and similarity caches."""
        self._analysis_cache.clear()
        self._similarity_cache.clear()
        logger.debug("Cleared spaCy analyzer caches")

    def get_cache_stats(self) -> dict[str, int]:
        """Get cache statistics for monitoring."""
        return {
            "analysis_cache_size": len(self._analysis_cache),
            "similarity_cache_size": len(self._similarity_cache),
        }