Coverage for src/qdrant_loader_mcp_server/cli.py: 51%

322 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-11 07:22 +0000

1"""CLI module for QDrant Loader MCP Server.""" 

2 

3import asyncio 

4import json 

5import logging 

6import os 

7import signal 

8import sys 

9import time 

10from pathlib import Path 

11 

12import click 

13from click.decorators import option 

14from click.types import Choice 

15from click.types import Path as ClickPath 

16from dotenv import load_dotenv 

17 

18from .config import Config 

19from .config_loader import load_config, redact_effective_config 

20from .mcp import MCPHandler 

21from .search.engine import SearchEngine 

22from .search.processor import QueryProcessor 

23from .transport import HTTPTransportHandler 

24from .utils import LoggingConfig, get_version 

25 

26# Suppress asyncio debug messages to reduce noise in logs. 

27logging.getLogger("asyncio").setLevel(logging.WARNING) 

28 

29 

def _setup_logging(log_level: str, transport: str | None = None) -> None:
    """Configure logging for the selected transport.

    In stdio mode console logging is force-disabled (stdout carries the
    JSON-RPC stream), pre-existing root handlers are removed to avoid
    duplicate records, and LoggingConfig is (re)initialized.
    """
    try:
        # stdio transport owns stdout for JSON-RPC traffic, so console
        # logging must be switched off before the flag is read below.
        if transport and transport.lower() == "stdio":
            os.environ["MCP_DISABLE_CONSOLE_LOGGING"] = "true"

        console_disabled = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )
        output_format = "json" if console_disabled else "console"
        level_name = log_level.upper()

        # Drop handlers installed by any earlier implicit setup() call (e.g.
        # during module import) so records are not emitted twice.
        root = logging.getLogger()
        for handler in list(root.handlers):
            try:
                root.removeHandler(handler)
            except Exception:
                pass

        # Prefer reconfigure() when the installed LoggingConfig supports it,
        # to avoid stacking handlers on repeated setup.
        reconfigure = getattr(LoggingConfig, "reconfigure", None)  # type: ignore[attr-defined]
        initialized = getattr(LoggingConfig, "_initialized", False)  # type: ignore[attr-defined]
        if reconfigure and initialized:
            # Already initialized: only switch the file target (none in stdio
            # mode unless provided via environment).
            LoggingConfig.reconfigure(file=os.getenv("MCP_LOG_FILE"))  # type: ignore[attr-defined]
        elif reconfigure:
            LoggingConfig.setup(level=level_name, format=output_format)
        else:
            # Older LoggingConfig versions: force-replace handlers first.
            logging.getLogger().handlers = []
            LoggingConfig.setup(level=level_name, format=output_format)
    except Exception as e:
        print(f"Failed to setup logging: {e}", file=sys.stderr)

70 

71 

async def read_stdin():
    """Return an asyncio StreamReader wired to this process's stdin.

    Connects a StreamReaderProtocol to sys.stdin on the running event loop
    so that input can be awaited without blocking.
    """
    event_loop = asyncio.get_running_loop()
    stream_reader = asyncio.StreamReader()
    reader_protocol = asyncio.StreamReaderProtocol(stream_reader)
    await event_loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
    return stream_reader

79 

80 

async def shutdown(
    loop: asyncio.AbstractEventLoop, shutdown_event: asyncio.Event | None = None
):
    """Handle graceful shutdown.

    Only signals the shared shutdown event; draining in-flight work and
    cleanup are performed by the server / shutdown monitor tasks.

    Args:
        loop: The running event loop (kept for interface compatibility;
            not used directly here).
        shutdown_event: Event used to coordinate shutdown across tasks;
            set when shutdown begins. May be ``None``.
    """
    logger = LoggingConfig.get_logger(__name__)
    logger.info("Shutting down...")

    # Only signal shutdown; let server/monitor handle draining and cleanup
    if shutdown_event:
        shutdown_event.set()

    # Yield control so that other tasks (e.g., shutdown monitor, server) can react
    try:
        await asyncio.sleep(0)
    except asyncio.CancelledError:
        # If the shutdown task itself is cancelled, just exit quietly
        return

    logger.info("Shutdown signal dispatched")

100 

101 

async def start_http_server(
    config: Config, log_level: str, host: str, port: int, shutdown_event: asyncio.Event
):
    """Start MCP server with HTTP transport.

    Initializes the search engine and MCP handler, then runs a uvicorn
    server for the HTTP transport until ``shutdown_event`` is set. A
    companion monitor task drains in-flight requests before allowing (or
    forcing) server exit.

    Args:
        config: Loaded server configuration (qdrant/openai/search sections used).
        log_level: Logging level name; also controls uvicorn access logging.
        host: Interface to bind the HTTP server to.
        port: TCP port to bind the HTTP server to.
        shutdown_event: Event set by the signal handler to request shutdown.

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)
    search_engine = None

    try:
        logger.info(f"Starting HTTP server on {host}:{port}")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine; failure is fatal for this transport.
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        # Create HTTP transport handler
        http_handler = HTTPTransportHandler(mcp_handler, host=host, port=port)

        # Start the FastAPI server using uvicorn (imported lazily so stdio
        # mode never pays the import cost).
        import uvicorn

        uvicorn_config = uvicorn.Config(
            app=http_handler.app,
            host=host,
            port=port,
            log_level=log_level.lower(),
            # Access logs only in DEBUG to keep output quiet by default.
            access_log=log_level.upper() == "DEBUG",
        )

        server = uvicorn.Server(uvicorn_config)
        logger.info(f"HTTP MCP server ready at http://{host}:{port}/mcp")

        # Create a task to monitor shutdown event and drive graceful draining.
        async def shutdown_monitor():
            try:
                await shutdown_event.wait()
                logger.info("Shutdown signal received, stopping HTTP server...")

                # Signal uvicorn to stop gracefully
                server.should_exit = True

                # Graceful drain logic: wait for in-flight requests to finish before forcing exit
                # Configurable timeouts via environment variables
                drain_timeout = float(
                    os.getenv("MCP_HTTP_DRAIN_TIMEOUT_SECONDS", "10.0")
                )
                max_shutdown_timeout = float(
                    os.getenv("MCP_HTTP_SHUTDOWN_TIMEOUT_SECONDS", "30.0")
                )

                start_ts = time.monotonic()

                # 1) Prioritize draining non-streaming requests quickly
                drained_non_stream = False
                try:
                    while time.monotonic() - start_ts < drain_timeout:
                        if not http_handler.has_inflight_non_streaming():
                            drained_non_stream = True
                            logger.info(
                                "Non-streaming requests drained; continuing shutdown"
                            )
                            break
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during drain phase")
                    return
                except Exception:
                    # On any error during drain check, fall through to timeout-based force
                    pass

                if not drained_non_stream:
                    logger.warning(
                        f"Non-streaming requests still in flight after {drain_timeout}s; proceeding with shutdown"
                    )

                # 2) Allow additional time (up to max_shutdown_timeout total) for all requests to complete
                total_deadline = start_ts + max_shutdown_timeout
                try:
                    while time.monotonic() < total_deadline:
                        counts = http_handler.get_inflight_request_counts()
                        if counts.get("total", 0) == 0:
                            logger.info(
                                "All in-flight requests drained; completing shutdown without force"
                            )
                            break
                        await asyncio.sleep(0.2)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during final drain phase")
                    return
                except Exception:
                    pass

                # 3) If still not finished after the max timeout, force the server to exit.
                # hasattr guard keeps this compatible with uvicorn versions
                # lacking force_exit.
                if hasattr(server, "force_exit"):
                    if time.monotonic() >= total_deadline:
                        logger.warning(
                            f"Forcing server exit after {max_shutdown_timeout}s shutdown timeout"
                        )
                        server.force_exit = True
                    else:
                        logger.debug(
                            "Server drained gracefully; force_exit not required"
                        )
            except asyncio.CancelledError:
                logger.debug("Shutdown monitor task cancelled")
                return

        # Start shutdown monitor task
        monitor_task = asyncio.create_task(shutdown_monitor())

        try:
            # Run the server until shutdown
            await server.serve()
        except asyncio.CancelledError:
            logger.info("Server shutdown initiated")
        except Exception as e:
            # Errors during an orderly shutdown are expected; only treat them
            # as real failures when no shutdown was requested.
            if not shutdown_event.is_set():
                logger.error(f"Server error: {e}", exc_info=True)
            else:
                logger.info(f"Server stopped during shutdown: {e}")
        finally:
            # Clean up the monitor task gracefully
            if monitor_task and not monitor_task.done():
                logger.debug("Cleaning up shutdown monitor task")
                monitor_task.cancel()
                try:
                    await asyncio.wait_for(monitor_task, timeout=2.0)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor task cancelled successfully")
                # NOTE(review): on Python < 3.11 asyncio.wait_for raises
                # asyncio.TimeoutError, which is NOT an alias of builtin
                # TimeoutError — confirm the minimum supported Python version.
                except TimeoutError:
                    logger.warning("Shutdown monitor task cleanup timed out")
                except Exception as e:
                    logger.debug(f"Shutdown monitor cleanup completed with: {e}")

    except Exception as e:
        if not shutdown_event.is_set():
            logger.error(f"Error in HTTP server: {e}", exc_info=True)
            raise
    finally:
        # Clean up search engine regardless of how the server exited.
        if search_engine:
            try:
                await search_engine.cleanup()
                logger.info("Search engine cleanup completed")
            except Exception as e:
                logger.error(f"Error during search engine cleanup: {e}", exc_info=True)

256 

257 

def _write_response(response: dict) -> None:
    """Serialize a JSON-RPC response to stdout and flush it immediately."""
    sys.stdout.write(json.dumps(response) + "\n")
    sys.stdout.flush()


def _jsonrpc_error(request_id, code: int, message: str, data: str) -> dict:
    """Build a JSON-RPC 2.0 error response payload."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": code,
            "message": message,
            "data": data,
        },
    }


async def handle_stdio(config: Config, log_level: str):
    """Handle stdio communication with Cursor.

    Reads newline-delimited JSON-RPC 2.0 requests from stdin, dispatches
    them to the MCP handler, and writes responses to stdout. Malformed
    input produces JSON-RPC error responses rather than terminating the
    loop; EOF or cancellation ends the loop.

    Args:
        config: Loaded server configuration (qdrant/openai/search sections used).
        log_level: Logging level name (unused here; logging is configured by the CLI).

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)

    # Check if console logging is disabled (it always is for stdio transport,
    # enforced by _setup_logging, since stdout carries the JSON-RPC stream).
    # Computed before the try block so the outer exception handler can always
    # reference it (previously a NameError was possible on very early failure).
    disable_console_logging = (
        os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
    )

    try:
        if not disable_console_logging:
            logger.info("Setting up stdio handler...")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine; failure is fatal for this transport.
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            if not disable_console_logging:
                logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        reader = await read_stdin()
        if not disable_console_logging:
            logger.info("Server ready to handle requests")

        while True:
            try:
                # Read a line from stdin
                if not disable_console_logging:
                    logger.debug("Waiting for input...")
                try:
                    line = await reader.readline()
                    if not line:
                        # EOF: the client closed stdin.
                        if not disable_console_logging:
                            logger.warning("No input received, breaking")
                        break
                except asyncio.CancelledError:
                    if not disable_console_logging:
                        logger.info("Read operation cancelled during shutdown")
                    break

                # Log the raw input
                raw_input = line.decode().strip()
                if not disable_console_logging:
                    logger.debug("Received raw input", raw_input=raw_input)

                # Parse the request
                try:
                    request = json.loads(raw_input)
                    if not disable_console_logging:
                        logger.debug("Parsed request", request=request)
                except json.JSONDecodeError as e:
                    if not disable_console_logging:
                        logger.error("Invalid JSON received", error=str(e))
                    # Send error response for invalid JSON
                    _write_response(
                        _jsonrpc_error(
                            None,
                            -32700,
                            "Parse error",
                            f"Invalid JSON received: {str(e)}",
                        )
                    )
                    continue

                # Validate request format
                if not isinstance(request, dict):
                    if not disable_console_logging:
                        logger.error("Request must be a JSON object")
                    _write_response(
                        _jsonrpc_error(
                            None,
                            -32600,
                            "Invalid Request",
                            "Request must be a JSON object",
                        )
                    )
                    continue

                if "jsonrpc" not in request or request["jsonrpc"] != "2.0":
                    if not disable_console_logging:
                        logger.error("Invalid JSON-RPC version")
                    _write_response(
                        _jsonrpc_error(
                            request.get("id"),
                            -32600,
                            "Invalid Request",
                            "Invalid JSON-RPC version",
                        )
                    )
                    continue

                # Process the request
                try:
                    response = await mcp_handler.handle_request(request)
                    if not disable_console_logging:
                        logger.debug("Sending response", response=response)
                    # Only write to stdout if response is not empty (not a notification)
                    if response:
                        _write_response(response)
                except Exception as e:
                    if not disable_console_logging:
                        logger.error("Error processing request", exc_info=True)
                    _write_response(
                        _jsonrpc_error(
                            request.get("id"),
                            -32603,
                            "Internal error",
                            str(e),
                        )
                    )

            except asyncio.CancelledError:
                if not disable_console_logging:
                    logger.info("Request handling cancelled during shutdown")
                break
            except Exception:
                # Keep serving after unexpected per-request failures.
                if not disable_console_logging:
                    logger.error("Error handling request", exc_info=True)
                continue

        # Cleanup
        await search_engine.cleanup()

    except Exception:
        if not disable_console_logging:
            logger.error("Error in stdio handler", exc_info=True)
        raise

405 

406 

@click.command(name="mcp-qdrant-loader")
@option(
    "--log-level",
    type=Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Set the logging level.",
)
# Hidden option to print effective config (redacts secrets)
@option(
    "--print-config",
    is_flag=True,
    default=False,
    help="Print the effective configuration (secrets redacted) and exit.",
)
@option(
    "--config",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to configuration file.",
)
@option(
    "--transport",
    type=Choice(["stdio", "http"], case_sensitive=False),
    default="stdio",
    help="Transport protocol to use (stdio for JSON-RPC over stdin/stdout, http for HTTP with SSE)",
)
@option(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to bind HTTP server to (only used with --transport http)",
)
@option(
    "--port",
    type=int,
    default=8080,
    help="Port to bind HTTP server to (only used with --transport http)",
)
@option(
    "--env",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to .env file to load environment variables from",
)
@click.version_option(
    version=get_version(),
    message="QDrant Loader MCP Server v%(version)s",
)
def cli(
    log_level: str = "INFO",
    config: Path | None = None,
    transport: str = "stdio",
    host: str = "127.0.0.1",
    port: int = 8080,
    env: Path | None = None,
    print_config: bool = False,
) -> None:
    """QDrant Loader MCP Server.

    A Model Context Protocol (MCP) server that provides RAG capabilities
    to Cursor and other LLM applications using Qdrant vector database.

    The server supports both stdio (JSON-RPC) and HTTP (with SSE) transports
    for maximum compatibility with different MCP clients.

    Environment Variables:
        QDRANT_URL: URL of your QDrant instance (required)
        QDRANT_API_KEY: API key for QDrant authentication
        QDRANT_COLLECTION_NAME: Name of the collection to use (default: "documents")
        OPENAI_API_KEY: OpenAI API key for embeddings (required)
        MCP_DISABLE_CONSOLE_LOGGING: Set to "true" to disable console logging

    Examples:
        # Start with stdio transport (default, for Cursor/Claude Desktop)
        mcp-qdrant-loader

        # Start with HTTP transport (for web clients)
        mcp-qdrant-loader --transport http --port 8080

        # Start with environment variables from .env file
        mcp-qdrant-loader --transport http --env /path/to/.env

        # Start with debug logging
        mcp-qdrant-loader --log-level DEBUG --transport http

        # Show help
        mcp-qdrant-loader --help

        # Show version
        mcp-qdrant-loader --version
    """
    loop = None
    try:
        # Load environment variables from .env file if specified
        if env:
            load_dotenv(env)

        # Setup logging (force-disable console logging in stdio transport)
        _setup_logging(log_level, transport)

        # Log env file load after logging is configured to avoid duplicate handler setup
        if env:
            LoggingConfig.get_logger(__name__).info(
                "Loaded environment variables", env=str(env)
            )

        # If a config file was provided, propagate it via MCP_CONFIG so that
        # any internal callers that resolve config without CLI context can find it.
        if config is not None:
            try:
                os.environ["MCP_CONFIG"] = str(config)
            except Exception:
                # Best-effort; continue without blocking startup
                pass

        # Initialize configuration (file/env precedence)
        config_obj, effective_cfg, used_file = load_config(config)

        # --print-config: dump the redacted effective config and exit without
        # starting any transport.
        if print_config:
            redacted = redact_effective_config(effective_cfg)
            click.echo(json.dumps(redacted, indent=2))
            return

        # Create and set the event loop explicitly so signal handlers and the
        # finally-block cleanup operate on a known loop instance.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create shutdown event for coordinating graceful shutdown
        shutdown_event = asyncio.Event()
        shutdown_task = None

        # Set up signal handlers with shutdown event
        def signal_handler():
            # Schedule shutdown on the explicit loop for clarity and correctness.
            # Guarded so repeated signals schedule only one shutdown task.
            nonlocal shutdown_task
            if shutdown_task is None:
                shutdown_task = loop.create_task(shutdown(loop, shutdown_event))

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Start the appropriate transport handler (blocks until shutdown)
        if transport.lower() == "stdio":
            loop.run_until_complete(handle_stdio(config_obj, log_level))
        elif transport.lower() == "http":
            loop.run_until_complete(
                start_http_server(config_obj, log_level, host, port, shutdown_event)
            )
        else:
            raise ValueError(f"Unsupported transport: {transport}")
    except Exception:
        logger = LoggingConfig.get_logger(__name__)
        logger.error("Error in main", exc_info=True)
        sys.exit(1)
    finally:
        if loop:
            try:
                # First, wait for the shutdown task if it exists.
                # "shutdown_task" may be unbound if an exception occurred
                # before its assignment above, hence the locals() guard.
                if (
                    "shutdown_task" in locals()
                    and shutdown_task is not None
                    and not shutdown_task.done()
                ):
                    try:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug("Waiting for shutdown task to complete...")
                        loop.run_until_complete(
                            asyncio.wait_for(shutdown_task, timeout=5.0)
                        )
                        logger.debug("Shutdown task completed successfully")
                    # NOTE(review): on Python < 3.11 asyncio.wait_for raises
                    # asyncio.TimeoutError, which is not builtin TimeoutError —
                    # confirm the minimum supported Python version.
                    except TimeoutError:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.warning("Shutdown task timed out, cancelling...")
                        shutdown_task.cancel()
                        try:
                            loop.run_until_complete(shutdown_task)
                        except asyncio.CancelledError:
                            logger.debug("Shutdown task cancelled successfully")
                    except Exception as e:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug(f"Shutdown task completed with: {e}")

                # Then cancel any remaining tasks (except completed shutdown task)
                def _cancel_all_pending_tasks():
                    """Cancel tasks safely without circular dependencies."""
                    # NOTE(review): this closure reads shutdown_task without the
                    # locals() guard used above — if an exception occurred before
                    # shutdown_task was bound, this raises NameError. Verify.
                    all_tasks = list(asyncio.all_tasks(loop))
                    if not all_tasks:
                        return

                    # Cancel all tasks except the completed shutdown task
                    cancelled_tasks = []
                    for task in all_tasks:
                        if not task.done() and task is not shutdown_task:
                            task.cancel()
                            cancelled_tasks.append(task)

                    # Don't await gather to avoid recursion - just let them finish on their own
                    # The loop will handle the cleanup when it closes
                    if cancelled_tasks:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.info(
                            f"Cancelled {len(cancelled_tasks)} remaining tasks for cleanup"
                        )

                _cancel_all_pending_tasks()
            except Exception:
                logger = LoggingConfig.get_logger(__name__)
                logger.error("Error during final cleanup", exc_info=True)
            finally:
                loop.close()
                logger = LoggingConfig.get_logger(__name__)
                logger.info("Server shutdown complete")

619 

620 

# Allow direct execution (e.g. `python cli.py`) in addition to the
# console-script entry point.
if __name__ == "__main__":
    cli()