Coverage for src/qdrant_loader_mcp_server/search/nlp/spacy_analyzer.py: 93%

182 statements  

coverage.py v7.10.6, created at 2025-09-08 06:06 +0000

1"""spaCy-powered query analysis for intelligent search.""" 

2 

3from dataclasses import dataclass 

4from typing import Any 

5 

6import spacy 

7from spacy.cli.download import download as spacy_download 

8from spacy.tokens import Doc 

9 

10from ...utils.logging import LoggingConfig 

11 

12logger = LoggingConfig.get_logger(__name__) 

13 

14 

15@dataclass 

16class QueryAnalysis: 

17 """Container for spaCy-based query analysis results.""" 

18 

19 # Core linguistic analysis 

20 entities: list[tuple[str, str]] # (text, label) 

21 pos_patterns: list[str] # Part-of-speech tags 

22 semantic_keywords: list[str] # Lemmatized, filtered keywords 

23 intent_signals: dict[str, Any] # Intent detection based on linguistic patterns 

24 main_concepts: list[str] # Noun chunks representing main concepts 

25 

26 # Semantic understanding 

27 query_vector: Any # spaCy Doc vector for similarity matching 

28 semantic_similarity_cache: dict[str, float] # Cache for similarity scores 

29 

30 # Query characteristics 

31 is_question: bool 

32 is_technical: bool 

33 complexity_score: float # 0-1 score based on linguistic complexity 

34 

35 # Processing metadata 

36 processed_tokens: int 

37 processing_time_ms: float 

38 

39 

40class SpaCyQueryAnalyzer: 

41 """Enhanced query analysis using spaCy NLP with en_core_web_md model.""" 

42 

43 def __init__(self, spacy_model: str = "en_core_web_md"): 

44 """Initialize the spaCy query analyzer. 

45 

46 Args: 

47 spacy_model: spaCy model to use (default: en_core_web_md with 20k word vectors) 

48 """ 

49 self.spacy_model = spacy_model 

50 self.nlp = self._load_spacy_model() 

51 self.logger = LoggingConfig.get_logger(__name__) 

52 

53 # Intent pattern definitions using POS tags and linguistic features 

54 self.intent_patterns = { 

55 "technical_lookup": { 

56 "entities": {"ORG", "PRODUCT", "PERSON", "GPE"}, 

57 "pos_sequences": [["NOUN", "NOUN"], ["ADJ", "NOUN"], ["VERB", "NOUN"]], 

58 "keywords": { 

59 "api", 

60 "database", 

61 "architecture", 

62 "implementation", 

63 "system", 

64 "code", 

65 "function", 

66 }, 

67 "question_words": set(), 

68 }, 

69 "business_context": { 

70 "entities": {"ORG", "MONEY", "PERCENT", "CARDINAL"}, 

71 "pos_sequences": [["NOUN", "NOUN"], ["ADJ", "NOUN", "NOUN"]], 

72 "keywords": { 

73 "requirements", 

74 "objectives", 

75 "strategy", 

76 "business", 

77 "scope", 

78 "goals", 

79 }, 

80 "question_words": {"what", "why", "how"}, 

81 }, 

82 "vendor_evaluation": { 

83 "entities": {"ORG", "MONEY", "PERSON"}, 

84 "pos_sequences": [["NOUN", "NOUN"], ["VERB", "NOUN"], ["ADJ", "NOUN"]], 

85 "keywords": { 

86 "proposal", 

87 "criteria", 

88 "cost", 

89 "vendor", 

90 "comparison", 

91 "evaluation", 

92 }, 

93 "question_words": {"which", "what", "how much"}, 

94 }, 

95 "procedural": { 

96 "entities": set(), 

97 "pos_sequences": [["VERB", "NOUN"], ["VERB", "DET", "NOUN"]], 

98 "keywords": { 

99 "how", 

100 "steps", 

101 "process", 

102 "procedure", 

103 "guide", 

104 "tutorial", 

105 }, 

106 "question_words": {"how", "when", "where"}, 

107 }, 

108 "informational": { 

109 "entities": set(), 

110 "pos_sequences": [["NOUN"], ["ADJ", "NOUN"]], 

111 "keywords": {"what", "definition", "meaning", "overview", "about"}, 

112 "question_words": {"what", "who", "when", "where"}, 

113 }, 

114 } 

115 

116 # Cache for processed queries to improve performance 

117 self._analysis_cache: dict[str, QueryAnalysis] = {} 

118 self._similarity_cache: dict[tuple[str, str], float] = {} 

119 

120 def _load_spacy_model(self) -> spacy.Language: 

121 """Load spaCy model with error handling and auto-download.""" 

122 try: 

123 nlp = spacy.load(self.spacy_model) 

124 # Verify model has vectors for semantic similarity 

125 if not nlp.meta.get("vectors", {}).get("vectors", 0): 

126 logger.warning( 

127 f"spaCy model {self.spacy_model} loaded but has no word vectors. " 

128 "Semantic similarity features will be limited." 

129 ) 

130 else: 

131 logger.info( 

132 f"spaCy model {self.spacy_model} loaded successfully with " 

133 f"{nlp.meta['vectors']['vectors']} word vectors" 

134 ) 

135 return nlp 

136 except OSError: 

137 logger.info(f"spaCy model {self.spacy_model} not found. Downloading...") 

138 try: 

139 spacy_download(self.spacy_model) 

140 nlp = spacy.load(self.spacy_model) 

141 logger.info(f"Successfully downloaded and loaded {self.spacy_model}") 

142 return nlp 

143 except Exception as e: 

144 logger.error(f"Failed to download spaCy model {self.spacy_model}: {e}") 

145 # Fallback to a basic model 

146 try: 

147 logger.warning("Falling back to en_core_web_sm model") 

148 spacy_download("en_core_web_sm") 

149 return spacy.load("en_core_web_sm") 

150 except Exception as fallback_error: 

151 logger.error(f"Failed to load fallback model: {fallback_error}") 

152 raise RuntimeError( 

153 f"Could not load any spaCy model. Please install {self.spacy_model} manually." 

154 ) 

155 

156 def analyze_query_semantic(self, query: str) -> QueryAnalysis: 

157 """Enhanced query analysis using spaCy NLP. 

158 

159 Args: 

160 query: The search query to analyze 

161 

162 Returns: 

163 QueryAnalysis containing comprehensive linguistic analysis 

164 """ 

165 import time 

166 

167 start_time = time.time() 

168 

169 # Check cache first 

170 if query in self._analysis_cache: 

171 cached = self._analysis_cache[query] 

172 logger.debug(f"Using cached analysis for query: {query[:50]}...") 

173 return cached 

174 

175 # Process query with spaCy 

176 doc = self.nlp(query) 

177 

178 # Extract entities with confidence 

179 entities = [(ent.text, ent.label_) for ent in doc.ents] 

180 

181 # Get POS patterns 

182 pos_patterns = [token.pos_ for token in doc if not token.is_space] 

183 

184 # Extract semantic keywords (lemmatized, filtered) 

185 semantic_keywords = [ 

186 token.lemma_.lower() 

187 for token in doc 

188 if ( 

189 token.is_alpha 

190 and not token.is_stop 

191 and not token.is_punct 

192 and len(token.text) > 2 

193 ) 

194 ] 

195 

196 # Extract main concepts (noun chunks) 

197 main_concepts = [ 

198 chunk.text.strip() 

199 for chunk in doc.noun_chunks 

200 if len(chunk.text.strip()) > 2 

201 ] 

202 

203 # Detect intent using linguistic patterns 

204 intent_signals = self._detect_intent_patterns( 

205 doc, entities, pos_patterns, semantic_keywords 

206 ) 

207 

208 # Query characteristics 

209 is_question = self._is_question(doc) 

210 is_technical = self._is_technical_query(doc, entities, semantic_keywords) 

211 complexity_score = self._calculate_complexity_score(doc) 

212 

213 # Processing metadata 

214 processing_time_ms = (time.time() - start_time) * 1000 

215 

216 # Create analysis result 

217 analysis = QueryAnalysis( 

218 entities=entities, 

219 pos_patterns=pos_patterns, 

220 semantic_keywords=semantic_keywords, 

221 intent_signals=intent_signals, 

222 main_concepts=main_concepts, 

223 query_vector=doc, # Store the spaCy Doc for similarity calculations 

224 semantic_similarity_cache={}, 

225 is_question=is_question, 

226 is_technical=is_technical, 

227 complexity_score=complexity_score, 

228 processed_tokens=len(doc), 

229 processing_time_ms=processing_time_ms, 

230 ) 

231 

232 # Cache the result 

233 self._analysis_cache[query] = analysis 

234 

235 logger.debug( 

236 f"Analyzed query in {processing_time_ms:.2f}ms", 

237 query_length=len(query), 

238 entities_found=len(entities), 

239 keywords_extracted=len(semantic_keywords), 

240 intent=intent_signals.get("primary_intent", "unknown"), 

241 ) 

242 

243 return analysis 

244 

245 def semantic_similarity_matching( 

246 self, query_analysis: QueryAnalysis, entity_text: str 

247 ) -> float: 

248 """Calculate semantic similarity using spaCy word vectors. 

249 

250 Args: 

251 query_analysis: Analyzed query containing the query vector 

252 entity_text: Text to compare similarity with 

253 

254 Returns: 

255 Similarity score between 0.0 and 1.0 

256 """ 

257 # Check cache first 

258 cache_key = (str(query_analysis.query_vector), entity_text) 

259 if cache_key in self._similarity_cache: 

260 return self._similarity_cache[cache_key] 

261 

262 try: 

263 # Process entity text 

264 entity_doc = self.nlp(entity_text) 

265 

266 # Calculate similarity using spaCy vectors 

267 if query_analysis.query_vector.has_vector and entity_doc.has_vector: 

268 similarity = query_analysis.query_vector.similarity(entity_doc) 

269 else: 

270 # Fallback to token-based similarity if no vectors 

271 similarity = self._token_similarity_fallback( 

272 query_analysis.semantic_keywords, entity_text.lower() 

273 ) 

274 

275 # Cache the result 

276 self._similarity_cache[cache_key] = similarity 

277 

278 return similarity 

279 

280 except Exception as e: 

281 logger.warning(f"Error calculating similarity for '{entity_text}': {e}") 

282 return 0.0 

283 

284 def _detect_intent_patterns( 

285 self, 

286 doc: Doc, 

287 entities: list[tuple[str, str]], 

288 pos_patterns: list[str], 

289 semantic_keywords: list[str], 

290 ) -> dict[str, Any]: 

291 """Detect query intent using POS patterns and linguistic features.""" 

292 intent_scores = {} 

293 

294 # Convert entities and keywords to sets for faster lookup 

295 entity_labels = {label for _, label in entities} 

296 keyword_set = set(semantic_keywords) 

297 

298 # Score each intent pattern 

299 for intent_name, pattern in self.intent_patterns.items(): 

300 score = 0.0 

301 

302 # Entity type matching 

303 entity_match = len(entity_labels.intersection(pattern["entities"])) / max( 

304 len(pattern["entities"]), 1 

305 ) 

306 score += entity_match * 0.3 

307 

308 # POS sequence matching 

309 pos_match = self._match_pos_sequences( 

310 pos_patterns, pattern["pos_sequences"] 

311 ) 

312 score += pos_match * 0.3 

313 

314 # Keyword matching 

315 keyword_match = len(keyword_set.intersection(pattern["keywords"])) / max( 

316 len(pattern["keywords"]), 1 

317 ) 

318 score += keyword_match * 0.2 

319 

320 # Question word matching 

321 question_match = self._match_question_words(doc, pattern["question_words"]) 

322 score += question_match * 0.2 

323 

324 intent_scores[intent_name] = score 

325 

326 # Find primary intent 

327 primary_intent = ( 

328 max(intent_scores, key=intent_scores.get) if intent_scores else "general" 

329 ) 

330 primary_score = intent_scores.get(primary_intent, 0.0) 

331 

332 # Only use intent if confidence is above threshold 

333 if primary_score < 0.3: 

334 primary_intent = "general" 

335 

336 return { 

337 "primary_intent": primary_intent, 

338 "confidence": primary_score, 

339 "all_scores": intent_scores, 

340 "linguistic_features": { 

341 "has_entities": len(entities) > 0, 

342 "has_question_words": any( 

343 token.text.lower() in {"what", "how", "why", "when", "who", "where"} 

344 for token in doc 

345 ), 

346 "verb_count": sum(1 for pos in pos_patterns if pos in {"VERB", "AUX"}), 

347 "noun_count": sum(1 for pos in pos_patterns if pos == "NOUN"), 

348 }, 

349 } 

350 

351 def _match_pos_sequences( 

352 self, pos_patterns: list[str], target_sequences: list[list[str]] 

353 ) -> float: 

354 """Match POS tag sequences in the query.""" 

355 if not target_sequences or not pos_patterns: 

356 return 0.0 

357 

358 matches = 0 

359 total_sequences = len(target_sequences) 

360 

361 for sequence in target_sequences: 

362 if self._contains_sequence(pos_patterns, sequence): 

363 matches += 1 

364 

365 return matches / total_sequences 

366 

367 def _contains_sequence(self, pos_patterns: list[str], sequence: list[str]) -> bool: 

368 """Check if POS patterns contain a specific sequence.""" 

369 if len(sequence) > len(pos_patterns): 

370 return False 

371 

372 for i in range(len(pos_patterns) - len(sequence) + 1): 

373 if pos_patterns[i : i + len(sequence)] == sequence: 

374 return True 

375 

376 return False 

377 

378 def _match_question_words(self, doc: Doc, question_words: set[str]) -> float: 

379 """Match question words in the query.""" 

380 if not question_words: 

381 return 0.0 

382 

383 found_words = { 

384 token.text.lower() for token in doc if token.text.lower() in question_words 

385 } 

386 return len(found_words) / len(question_words) 

387 

388 def _is_question(self, doc: Doc) -> bool: 

389 """Detect if the query is a question using linguistic features.""" 

390 # Check for question marks 

391 if "?" in doc.text: 

392 return True 

393 

394 # Check for question words at the beginning 

395 question_words = { 

396 "what", 

397 "how", 

398 "why", 

399 "when", 

400 "who", 

401 "where", 

402 "which", 

403 "whose", 

404 "whom", 

405 } 

406 first_token = doc[0] if doc else None 

407 if first_token and first_token.text.lower() in question_words: 

408 return True 

409 

410 # Check for auxiliary verbs at the beginning (e.g., "Can you", "Do we", "Is there") 

411 if len(doc) >= 2: 

412 first_two = [token.text.lower() for token in doc[:2]] 

413 aux_patterns = { 

414 ("can", "you"), 

415 ("do", "we"), 

416 ("is", "there"), 

417 ("are", "there"), 

418 ("will", "you"), 

419 } 

420 if tuple(first_two) in aux_patterns: 

421 return True 

422 

423 return False 

424 

425 def _is_technical_query( 

426 self, doc: Doc, entities: list[tuple[str, str]], keywords: list[str] 

427 ) -> bool: 

428 """Detect if the query is technical in nature.""" 

429 technical_indicators = { 

430 "api", 

431 "database", 

432 "system", 

433 "code", 

434 "function", 

435 "architecture", 

436 "implementation", 

437 "framework", 

438 "library", 

439 "server", 

440 "client", 

441 "protocol", 

442 "algorithm", 

443 "data", 

444 "query", 

445 "schema", 

446 "endpoint", 

447 } 

448 

449 # Check keywords 

450 keyword_set = set(keywords) 

451 if keyword_set.intersection(technical_indicators): 

452 return True 

453 

454 # Check for technical entity types 

455 technical_entities = { 

456 "ORG", 

457 "PRODUCT", 

458 "LANGUAGE", 

459 } # Often technical in this context 

460 entity_labels = {label for _, label in entities} 

461 if entity_labels.intersection(technical_entities): 

462 return True 

463 

464 return False 

465 

466 def _calculate_complexity_score(self, doc: Doc) -> float: 

467 """Calculate query complexity based on linguistic features.""" 

468 if not doc: 

469 return 0.0 

470 

471 # Factors that contribute to complexity 

472 factors = { 

473 "length": min(len(doc) / 20, 1.0), # Longer queries are more complex 

474 "entities": min(len(doc.ents) / 5, 1.0), # More entities = more complex 

475 "noun_chunks": min( 

476 len(list(doc.noun_chunks)) / 5, 1.0 

477 ), # More concepts = more complex 

478 "question_words": min( 

479 sum( 

480 1 

481 for token in doc 

482 if token.text.lower() 

483 in {"what", "how", "why", "when", "who", "where", "which"} 

484 ) 

485 / 3, 

486 1.0, 

487 ), 

488 "dependency_depth": min(self._max_dependency_depth(doc) / 5, 1.0), 

489 } 

490 

491 # Weighted average 

492 weights = { 

493 "length": 0.2, 

494 "entities": 0.3, 

495 "noun_chunks": 0.2, 

496 "question_words": 0.15, 

497 "dependency_depth": 0.15, 

498 } 

499 

500 complexity = sum(factors[key] * weights[key] for key in factors) 

501 return min(complexity, 1.0) 

502 

503 def _max_dependency_depth(self, doc: Doc) -> int: 

504 """Calculate maximum dependency tree depth.""" 

505 max_depth = 0 

506 

507 def get_depth(token, current_depth=0): 

508 nonlocal max_depth 

509 max_depth = max(max_depth, current_depth) 

510 for child in token.children: 

511 get_depth(child, current_depth + 1) 

512 

513 for token in doc: 

514 if token.head == token: # Root token 

515 get_depth(token) 

516 

517 return max_depth 

518 

519 def _token_similarity_fallback( 

520 self, query_keywords: list[str], entity_text: str 

521 ) -> float: 

522 """Fallback similarity calculation when word vectors are unavailable.""" 

523 if not query_keywords: 

524 return 0.0 

525 

526 entity_words = set(entity_text.lower().split()) 

527 query_word_set = set(query_keywords) 

528 

529 # Simple Jaccard similarity 

530 intersection = query_word_set.intersection(entity_words) 

531 union = query_word_set.union(entity_words) 

532 

533 return len(intersection) / len(union) if union else 0.0 

534 

535 def clear_cache(self): 

536 """Clear analysis and similarity caches.""" 

537 self._analysis_cache.clear() 

538 self._similarity_cache.clear() 

539 logger.debug("Cleared spaCy analyzer caches") 

540 

541 def get_cache_stats(self) -> dict[str, int]: 

542 """Get cache statistics for monitoring.""" 

543 return { 

544 "analysis_cache_size": len(self._analysis_cache), 

545 "similarity_cache_size": len(self._similarity_cache), 

546 }
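
For orientation, here is a minimal usage sketch of the class above. It is an illustration, not part of the module: the import path assumes the package is importable as qdrant_loader_mcp_server, and the example query and printed values are hypothetical.

# Minimal usage sketch; assumes the package is importable as
# qdrant_loader_mcp_server and that en_core_web_md is installed
# (or can be downloaded on first use).
from qdrant_loader_mcp_server.search.nlp.spacy_analyzer import SpaCyQueryAnalyzer

analyzer = SpaCyQueryAnalyzer()

# Analyze a query; results are cached per query string.
analysis = analyzer.analyze_query_semantic("How do I configure the database API endpoint?")
print(analysis.intent_signals["primary_intent"])    # e.g. "procedural" or "technical_lookup"
print(analysis.is_question, analysis.is_technical)  # expected: True True
print(analysis.semantic_keywords)                   # lemmatized keywords, e.g. ["configure", "database", ...]

# Compare the query with a candidate label using word vectors
# (falls back to Jaccard overlap when vectors are unavailable).
score = analyzer.semantic_similarity_matching(analysis, "database configuration")
print(f"similarity: {score:.2f}")

# Both caches grow with distinct inputs; clear them when appropriate.
print(analyzer.get_cache_stats())  # e.g. {"analysis_cache_size": 1, "similarity_cache_size": 1}
analyzer.clear_cache()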