Coverage for src/qdrant_loader_mcp_server/cli.py: 51%

318 statements  

Generated by coverage.py v7.10.6 on 2025-09-08 06:06 +0000

1"""CLI module for QDrant Loader MCP Server.""" 

2 

3import asyncio 

4import json 

5import logging 

6import os 

7import signal 

8import sys 

9import time 

10from pathlib import Path 

11 

12import click 

13from click.decorators import option 

14from click.types import Choice 

15from click.types import Path as ClickPath 

16from dotenv import load_dotenv 

17 

18from .config import Config 

19from .config_loader import load_config, redact_effective_config 

20from .mcp import MCPHandler 

21from .search.engine import SearchEngine 

22from .search.processor import QueryProcessor 

23from .transport import HTTPTransportHandler 

24from .utils import LoggingConfig, get_version 

25 

# Suppress asyncio's own debug/info chatter to reduce noise in logs;
# WARNING and above from the event loop are still emitted.
logging.getLogger("asyncio").setLevel(logging.WARNING)

28 

29 

def _setup_logging(log_level: str, transport: str | None = None) -> None:
    """Configure application logging for the requested transport.

    In stdio mode, console logging is force-disabled so that stdout carries
    only JSON-RPC payloads. Any pre-existing root handlers are removed to
    avoid duplicate log records. Failures are reported to stderr rather
    than raised, so logging problems never abort server startup.
    """
    try:
        # stdio transport owns stdout for JSON-RPC traffic, so console
        # output must be suppressed before reading the env flag below.
        if transport and transport.lower() == "stdio":
            os.environ["MCP_DISABLE_CONSOLE_LOGGING"] = "true"

        console_disabled = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )

        # Remove handlers installed by implicit setup during module imports,
        # otherwise records would be emitted twice once we reconfigure.
        root = logging.getLogger()
        for handler in list(root.handlers):
            try:
                root.removeHandler(handler)
            except Exception:
                pass

        chosen_format = "json" if console_disabled else "console"
        LoggingConfig.setup(level=log_level.upper(), format=chosen_format)
    except Exception as e:
        print(f"Failed to setup logging: {e}", file=sys.stderr)

58 

59 

async def read_stdin():
    """Return an :class:`asyncio.StreamReader` attached to this process's stdin."""
    stream_reader = asyncio.StreamReader()
    stream_protocol = asyncio.StreamReaderProtocol(stream_reader)
    running_loop = asyncio.get_running_loop()
    # Bridge the blocking stdin file object into the event loop.
    await running_loop.connect_read_pipe(lambda: stream_protocol, sys.stdin)
    return stream_reader

67 

68 

async def shutdown(
    loop: asyncio.AbstractEventLoop, shutdown_event: asyncio.Event | None = None
):
    """Handle graceful shutdown.

    Sets ``shutdown_event`` so that the server and its shutdown monitor can
    drain in-flight work and clean up on their own; this coroutine does not
    tear anything down itself.

    Args:
        loop: The running event loop (kept for interface compatibility;
            currently unused).
        shutdown_event: Event used to signal other tasks that shutdown began.
            May be ``None``, in which case only the yield/log steps run.
    """
    logger = LoggingConfig.get_logger(__name__)
    logger.info("Shutting down...")

    # Only signal shutdown; let server/monitor handle draining and cleanup
    if shutdown_event:
        shutdown_event.set()

    # Yield control so that other tasks (e.g., shutdown monitor, server) can react
    try:
        await asyncio.sleep(0)
    except asyncio.CancelledError:
        # If shutdown task is cancelled, just exit quietly
        return

    logger.info("Shutdown signal dispatched")

88 

89 

async def start_http_server(
    config: Config, log_level: str, host: str, port: int, shutdown_event: asyncio.Event
):
    """Start MCP server with HTTP transport.

    Initializes the search engine and MCP handler, then serves the HTTP
    transport's FastAPI app via uvicorn until ``shutdown_event`` is set.
    A background monitor task watches the event and drains in-flight
    requests before allowing (or forcing) the server to exit.

    Args:
        config: Loaded server configuration; ``qdrant``, ``openai`` and
            ``search`` sections are passed to the search engine.
        log_level: Logging level name; also sets uvicorn's log level and
            enables its access log when DEBUG.
        host: Interface to bind the HTTP server to.
        port: TCP port to bind the HTTP server to.
        shutdown_event: Set externally (signal handler) to request a
            graceful shutdown.

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)
    # Track the engine so the outer ``finally`` can clean it up even when
    # startup fails partway through.
    search_engine = None

    try:
        logger.info(f"Starting HTTP server on {host}:{port}")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        # Create HTTP transport handler
        http_handler = HTTPTransportHandler(mcp_handler, host=host, port=port)

        # Start the FastAPI server using uvicorn
        # (imported lazily so stdio-only deployments don't need uvicorn at import time)
        import uvicorn

        uvicorn_config = uvicorn.Config(
            app=http_handler.app,
            host=host,
            port=port,
            log_level=log_level.lower(),
            access_log=log_level.upper() == "DEBUG",
        )

        server = uvicorn.Server(uvicorn_config)
        logger.info(f"HTTP MCP server ready at http://{host}:{port}/mcp")

        # Create a task to monitor shutdown event
        async def shutdown_monitor():
            """Wait for the shutdown event, then drain requests and stop uvicorn."""
            try:
                await shutdown_event.wait()
                logger.info("Shutdown signal received, stopping HTTP server...")

                # Signal uvicorn to stop gracefully
                server.should_exit = True

                # Graceful drain logic: wait for in-flight requests to finish before forcing exit
                # Configurable timeouts via environment variables
                drain_timeout = float(
                    os.getenv("MCP_HTTP_DRAIN_TIMEOUT_SECONDS", "10.0")
                )
                max_shutdown_timeout = float(
                    os.getenv("MCP_HTTP_SHUTDOWN_TIMEOUT_SECONDS", "30.0")
                )

                start_ts = time.monotonic()

                # 1) Prioritize draining non-streaming requests quickly
                drained_non_stream = False
                try:
                    while time.monotonic() - start_ts < drain_timeout:
                        if not http_handler.has_inflight_non_streaming():
                            drained_non_stream = True
                            logger.info(
                                "Non-streaming requests drained; continuing shutdown"
                            )
                            break
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during drain phase")
                    return
                except Exception:
                    # On any error during drain check, fall through to timeout-based force
                    pass

                if not drained_non_stream:
                    logger.warning(
                        f"Non-streaming requests still in flight after {drain_timeout}s; proceeding with shutdown"
                    )

                # 2) Allow additional time (up to max_shutdown_timeout total) for all requests to complete
                total_deadline = start_ts + max_shutdown_timeout
                try:
                    while time.monotonic() < total_deadline:
                        counts = http_handler.get_inflight_request_counts()
                        if counts.get("total", 0) == 0:
                            logger.info(
                                "All in-flight requests drained; completing shutdown without force"
                            )
                            break
                        await asyncio.sleep(0.2)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during final drain phase")
                    return
                except Exception:
                    pass

                # 3) If still not finished after the max timeout, force the server to exit
                # (hasattr guard keeps this compatible with uvicorn versions lacking force_exit)
                if hasattr(server, "force_exit"):
                    if time.monotonic() >= total_deadline:
                        logger.warning(
                            f"Forcing server exit after {max_shutdown_timeout}s shutdown timeout"
                        )
                        server.force_exit = True
                    else:
                        logger.debug(
                            "Server drained gracefully; force_exit not required"
                        )
            except asyncio.CancelledError:
                logger.debug("Shutdown monitor task cancelled")
                return

        # Start shutdown monitor task
        monitor_task = asyncio.create_task(shutdown_monitor())

        try:
            # Run the server until shutdown
            await server.serve()
        except asyncio.CancelledError:
            logger.info("Server shutdown initiated")
        except Exception as e:
            # During an intentional shutdown, errors from the stopping server
            # are expected noise, not failures.
            if not shutdown_event.is_set():
                logger.error(f"Server error: {e}", exc_info=True)
            else:
                logger.info(f"Server stopped during shutdown: {e}")
        finally:
            # Clean up the monitor task gracefully
            if monitor_task and not monitor_task.done():
                logger.debug("Cleaning up shutdown monitor task")
                monitor_task.cancel()
                try:
                    await asyncio.wait_for(monitor_task, timeout=2.0)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor task cancelled successfully")
                # NOTE(review): on Python 3.11+ asyncio.wait_for raises TimeoutError;
                # on older runtimes it raises asyncio.TimeoutError, which this clause
                # may miss — confirm the supported Python versions.
                except TimeoutError:
                    logger.warning("Shutdown monitor task cleanup timed out")
                except Exception as e:
                    logger.debug(f"Shutdown monitor cleanup completed with: {e}")

    except Exception as e:
        if not shutdown_event.is_set():
            logger.error(f"Error in HTTP server: {e}", exc_info=True)
        raise
    finally:
        # Clean up search engine
        if search_engine:
            try:
                await search_engine.cleanup()
                logger.info("Search engine cleanup completed")
            except Exception as e:
                logger.error(f"Error during search engine cleanup: {e}", exc_info=True)

245 

def _jsonrpc_error(request_id, code: int, message: str, data: str) -> dict:
    """Build a JSON-RPC 2.0 error response payload."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {
            "code": code,
            "message": message,
            "data": data,
        },
    }


def _write_stdout_response(response: dict) -> None:
    """Serialize a response to stdout as one line and flush immediately."""
    sys.stdout.write(json.dumps(response) + "\n")
    sys.stdout.flush()


async def handle_stdio(config: Config, log_level: str):
    """Handle stdio communication with Cursor.

    Reads newline-delimited JSON-RPC 2.0 requests from stdin, dispatches them
    to the MCP handler, and writes responses to stdout. Runs until stdin is
    closed or the task is cancelled.

    Args:
        config: Loaded server configuration; ``qdrant``, ``openai`` and
            ``search`` sections are passed to the search engine.
        log_level: Logging level name (accepted for interface parity; logging
            itself is configured by the CLI before this coroutine runs).

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)

    try:
        # Console logging must stay off in stdio mode so stdout carries only
        # JSON-RPC payloads; honor the env flag set by _setup_logging.
        disable_console_logging = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )

        if not disable_console_logging:
            logger.info("Setting up stdio handler...")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        try:
            # Initialize search engine
            try:
                await search_engine.initialize(
                    config.qdrant, config.openai, config.search
                )
                if not disable_console_logging:
                    logger.info("Search engine initialized successfully")
            except Exception as e:
                logger.error("Failed to initialize search engine", exc_info=True)
                raise RuntimeError("Failed to initialize search engine") from e

            reader = await read_stdin()
            if not disable_console_logging:
                logger.info("Server ready to handle requests")

            while True:
                try:
                    # Read a line from stdin
                    if not disable_console_logging:
                        logger.debug("Waiting for input...")
                    try:
                        line = await reader.readline()
                        if not line:
                            # EOF: the client closed stdin, so we are done.
                            if not disable_console_logging:
                                logger.warning("No input received, breaking")
                            break
                    except asyncio.CancelledError:
                        if not disable_console_logging:
                            logger.info("Read operation cancelled during shutdown")
                        break

                    # Log the raw input
                    raw_input = line.decode().strip()
                    if not disable_console_logging:
                        logger.debug("Received raw input", raw_input=raw_input)

                    # Parse the request
                    try:
                        request = json.loads(raw_input)
                        if not disable_console_logging:
                            logger.debug("Parsed request", request=request)
                    except json.JSONDecodeError as e:
                        if not disable_console_logging:
                            logger.error("Invalid JSON received", error=str(e))
                        _write_stdout_response(
                            _jsonrpc_error(
                                None,
                                -32700,
                                "Parse error",
                                f"Invalid JSON received: {str(e)}",
                            )
                        )
                        continue

                    # Validate request format
                    if not isinstance(request, dict):
                        if not disable_console_logging:
                            logger.error("Request must be a JSON object")
                        _write_stdout_response(
                            _jsonrpc_error(
                                None,
                                -32600,
                                "Invalid Request",
                                "Request must be a JSON object",
                            )
                        )
                        continue

                    if "jsonrpc" not in request or request["jsonrpc"] != "2.0":
                        if not disable_console_logging:
                            logger.error("Invalid JSON-RPC version")
                        _write_stdout_response(
                            _jsonrpc_error(
                                request.get("id"),
                                -32600,
                                "Invalid Request",
                                "Invalid JSON-RPC version",
                            )
                        )
                        continue

                    # Process the request
                    try:
                        response = await mcp_handler.handle_request(request)
                        if not disable_console_logging:
                            logger.debug("Sending response", response=response)
                        # Only write to stdout if response is not empty (not a notification)
                        if response:
                            _write_stdout_response(response)
                    except Exception as e:
                        if not disable_console_logging:
                            logger.error("Error processing request", exc_info=True)
                        _write_stdout_response(
                            _jsonrpc_error(
                                request.get("id"),
                                -32603,
                                "Internal error",
                                str(e),
                            )
                        )

                except asyncio.CancelledError:
                    if not disable_console_logging:
                        logger.info("Request handling cancelled during shutdown")
                    break
                except Exception:
                    if not disable_console_logging:
                        logger.error("Error handling request", exc_info=True)
                    continue
        finally:
            # Always release engine resources — previously cleanup only ran on
            # a clean loop exit, leaking the engine when an error escaped.
            try:
                await search_engine.cleanup()
            except Exception:
                # Cleanup failures must not mask the original error (if any).
                if not disable_console_logging:
                    logger.error(
                        "Error during search engine cleanup", exc_info=True
                    )

    except Exception:
        if not disable_console_logging:
            logger.error("Error in stdio handler", exc_info=True)
        raise

393 

394 

@click.command(name="mcp-qdrant-loader")
@option(
    "--log-level",
    type=Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Set the logging level.",
)
# Hidden option to print effective config (redacts secrets)
@option(
    "--print-config",
    is_flag=True,
    default=False,
    help="Print the effective configuration (secrets redacted) and exit.",
)
@option(
    "--config",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to configuration file.",
)
@option(
    "--transport",
    type=Choice(["stdio", "http"], case_sensitive=False),
    default="stdio",
    help="Transport protocol to use (stdio for JSON-RPC over stdin/stdout, http for HTTP with SSE)",
)
@option(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to bind HTTP server to (only used with --transport http)",
)
@option(
    "--port",
    type=int,
    default=8080,
    help="Port to bind HTTP server to (only used with --transport http)",
)
@option(
    "--env",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to .env file to load environment variables from",
)
@click.version_option(
    version=get_version(),
    message="QDrant Loader MCP Server v%(version)s",
)
def cli(
    log_level: str = "INFO",
    config: Path | None = None,
    transport: str = "stdio",
    host: str = "127.0.0.1",
    port: int = 8080,
    env: Path | None = None,
    print_config: bool = False,
) -> None:
    """QDrant Loader MCP Server.

    A Model Context Protocol (MCP) server that provides RAG capabilities
    to Cursor and other LLM applications using Qdrant vector database.

    The server supports both stdio (JSON-RPC) and HTTP (with SSE) transports
    for maximum compatibility with different MCP clients.

    Environment Variables:
        QDRANT_URL: URL of your QDrant instance (required)
        QDRANT_API_KEY: API key for QDrant authentication
        QDRANT_COLLECTION_NAME: Name of the collection to use (default: "documents")
        OPENAI_API_KEY: OpenAI API key for embeddings (required)
        MCP_DISABLE_CONSOLE_LOGGING: Set to "true" to disable console logging

    Examples:
        # Start with stdio transport (default, for Cursor/Claude Desktop)
        mcp-qdrant-loader

        # Start with HTTP transport (for web clients)
        mcp-qdrant-loader --transport http --port 8080

        # Start with environment variables from .env file
        mcp-qdrant-loader --transport http --env /path/to/.env

        # Start with debug logging
        mcp-qdrant-loader --log-level DEBUG --transport http

        # Show help
        mcp-qdrant-loader --help

        # Show version
        mcp-qdrant-loader --version
    """
    # Loop is created only after config is resolved; None means "nothing to
    # tear down" in the finally block below.
    loop = None
    try:
        # Load environment variables from .env file if specified
        if env:
            load_dotenv(env)

        # Setup logging (force-disable console logging in stdio transport)
        _setup_logging(log_level, transport)

        # Log env file load after logging is configured to avoid duplicate handler setup
        if env:
            LoggingConfig.get_logger(__name__).info(
                "Loaded environment variables", env=str(env)
            )

        # If a config file was provided, propagate it via MCP_CONFIG so that
        # any internal callers that resolve config without CLI context can find it.
        if config is not None:
            try:
                os.environ["MCP_CONFIG"] = str(config)
            except Exception:
                # Best-effort; continue without blocking startup
                pass

        # Initialize configuration (file/env precedence)
        config_obj, effective_cfg, used_file = load_config(config)

        # --print-config short-circuits before any event loop is created.
        if print_config:
            redacted = redact_effective_config(effective_cfg)
            click.echo(json.dumps(redacted, indent=2))
            return

        # Create and set the event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create shutdown event for coordinating graceful shutdown
        shutdown_event = asyncio.Event()
        shutdown_task = None

        # Set up signal handlers with shutdown event
        def signal_handler():
            # Schedule shutdown on the explicit loop for clarity and correctness
            nonlocal shutdown_task
            if shutdown_task is None:
                shutdown_task = loop.create_task(shutdown(loop, shutdown_event))

        # NOTE(review): loop.add_signal_handler is not implemented on Windows
        # event loops — confirm the supported platforms for this entry point.
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Start the appropriate transport handler
        if transport.lower() == "stdio":
            loop.run_until_complete(handle_stdio(config_obj, log_level))
        elif transport.lower() == "http":
            loop.run_until_complete(
                start_http_server(config_obj, log_level, host, port, shutdown_event)
            )
        else:
            raise ValueError(f"Unsupported transport: {transport}")
    except Exception:
        logger = LoggingConfig.get_logger(__name__)
        logger.error("Error in main", exc_info=True)
        sys.exit(1)
    finally:
        if loop:
            try:
                # First, wait for the shutdown task if it exists
                if (
                    "shutdown_task" in locals()
                    and shutdown_task is not None
                    and not shutdown_task.done()
                ):
                    try:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug("Waiting for shutdown task to complete...")
                        loop.run_until_complete(
                            asyncio.wait_for(shutdown_task, timeout=5.0)
                        )
                        logger.debug("Shutdown task completed successfully")
                    # NOTE(review): on pre-3.11 runtimes asyncio.wait_for raises
                    # asyncio.TimeoutError, which this clause may miss — confirm
                    # the supported Python versions.
                    except TimeoutError:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.warning("Shutdown task timed out, cancelling...")
                        shutdown_task.cancel()
                        try:
                            loop.run_until_complete(shutdown_task)
                        except asyncio.CancelledError:
                            logger.debug("Shutdown task cancelled successfully")
                    except Exception as e:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug(f"Shutdown task completed with: {e}")

                # Then cancel any remaining tasks (except completed shutdown task)
                def _cancel_all_pending_tasks():
                    """Cancel tasks safely without circular dependencies."""
                    all_tasks = list(asyncio.all_tasks(loop))
                    if not all_tasks:
                        return

                    # Cancel all tasks except the completed shutdown task
                    cancelled_tasks = []
                    for task in all_tasks:
                        if not task.done() and task is not shutdown_task:
                            task.cancel()
                            cancelled_tasks.append(task)

                    # Don't await gather to avoid recursion - just let them finish on their own
                    # The loop will handle the cleanup when it closes
                    if cancelled_tasks:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.info(
                            f"Cancelled {len(cancelled_tasks)} remaining tasks for cleanup"
                        )

                _cancel_all_pending_tasks()
            except Exception:
                logger = LoggingConfig.get_logger(__name__)
                logger.error("Error during final cleanup", exc_info=True)
            finally:
                loop.close()
                logger = LoggingConfig.get_logger(__name__)
                logger.info("Server shutdown complete")

607 

608 

# Module entry point: allows running the CLI directly via `python cli.py`
# in addition to the installed `mcp-qdrant-loader` console script.
if __name__ == "__main__":
    cli()