Coverage report for src/qdrant_loader_mcp_server/cli.py — 51% of 300 statements covered (coverage.py v7.10.3, generated 2025-08-13 09:20 +0000).

1"""CLI module for QDrant Loader MCP Server.""" 

2 

3import asyncio 

4import json 

5import logging 

6import os 

7import signal 

8import sys 

9import time 

10from pathlib import Path 

11 

12import click 

13from click.decorators import option 

14from click.types import Choice 

15from click.types import Path as ClickPath 

16from dotenv import load_dotenv 

17 

18from .config import Config 

19from .mcp import MCPHandler 

20from .search.engine import SearchEngine 

21from .search.processor import QueryProcessor 

22from .transport import HTTPTransportHandler 

23from .utils import LoggingConfig, get_version 

24 

# Suppress asyncio debug messages to reduce noise in logs.
# (Applied at import time so it takes effect before any event loop is created.)
logging.getLogger("asyncio").setLevel(logging.WARNING)

27 

28 

def _setup_logging(log_level: str, transport: str | None = None) -> None:
    """Configure application logging.

    When the stdio transport is selected, console logging is force-disabled
    (via the ``MCP_DISABLE_CONSOLE_LOGGING`` environment variable) so that
    nothing but JSON-RPC traffic ever reaches stdout. Setup failures are
    reported on stderr instead of being raised.

    Args:
        log_level: Logging level name (e.g. "INFO", "DEBUG").
        transport: Transport name; "stdio" triggers the console-logging override.
    """
    try:
        # stdio mode: stdout is the protocol channel, so suppress console logs.
        if transport and transport.lower() == "stdio":
            os.environ["MCP_DISABLE_CONSOLE_LOGGING"] = "true"

        # Re-read the flag only after the possible override above.
        console_disabled = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )

        # Console format goes to stderr via our logging config; JSON otherwise.
        chosen_format = "json" if console_disabled else "console"
        LoggingConfig.setup(level=log_level.upper(), format=chosen_format)
    except Exception as exc:
        print(f"Failed to setup logging: {exc}", file=sys.stderr)

48 

49 

async def read_stdin():
    """Attach an asyncio ``StreamReader`` to ``sys.stdin`` and return it.

    Returns:
        An ``asyncio.StreamReader`` wired to the process's stdin pipe,
        suitable for ``await reader.readline()`` in the stdio loop.
    """
    stream_reader = asyncio.StreamReader()
    reader_protocol = asyncio.StreamReaderProtocol(stream_reader)
    running_loop = asyncio.get_running_loop()
    await running_loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
    return stream_reader

57 

58 

async def shutdown(
    loop: asyncio.AbstractEventLoop, shutdown_event: asyncio.Event | None = None
):
    """Handle graceful shutdown.

    Args:
        loop: The running event loop. Kept for interface compatibility with
            existing callers; not used directly — draining and cleanup are
            delegated to the server / shutdown-monitor tasks.
        shutdown_event: Optional event that, when set, tells the transport
            handlers to begin their own graceful shutdown.
    """
    logger = LoggingConfig.get_logger(__name__)
    logger.info("Shutting down...")

    # Only signal shutdown; let server/monitor handle draining and cleanup.
    if shutdown_event:
        shutdown_event.set()

    # Yield control so that other tasks (e.g., shutdown monitor, server) can react.
    try:
        await asyncio.sleep(0)
    except asyncio.CancelledError:
        # If the shutdown task itself is cancelled, just exit quietly.
        return

    logger.info("Shutdown signal dispatched")

78 

79 

async def start_http_server(
    config: Config, log_level: str, host: str, port: int, shutdown_event: asyncio.Event
):
    """Start MCP server with HTTP transport.

    Initializes the search engine and MCP handler, then runs a uvicorn server
    hosting the HTTP transport until ``shutdown_event`` is set, draining
    in-flight requests before forcing exit.

    Args:
        config: Application configuration (supplies the qdrant, openai and
            search sections used below).
        log_level: Logging level name; also sets uvicorn's log level, and
            enables the access log only at DEBUG.
        host: Interface to bind the HTTP server to.
        port: TCP port to bind the HTTP server to.
        shutdown_event: Event used to coordinate graceful shutdown with the
            process-level signal handlers.

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)
    search_engine = None  # tracked so the finally block can always clean it up

    try:
        logger.info(f"Starting HTTP server on {host}:{port}")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        # Create HTTP transport handler
        http_handler = HTTPTransportHandler(mcp_handler, host=host, port=port)

        # Start the FastAPI server using uvicorn
        # (imported lazily so stdio-only runs don't pay for it)
        import uvicorn

        uvicorn_config = uvicorn.Config(
            app=http_handler.app,
            host=host,
            port=port,
            log_level=log_level.lower(),
            # Access log is noisy; only enable it when debugging.
            access_log=log_level.upper() == "DEBUG",
        )

        server = uvicorn.Server(uvicorn_config)
        logger.info(f"HTTP MCP server ready at http://{host}:{port}/mcp")

        # Create a task to monitor shutdown event
        async def shutdown_monitor():
            # Runs alongside server.serve(); on shutdown_event it asks uvicorn
            # to exit and then drains in-flight requests in two phases before
            # (optionally) forcing exit.
            try:
                await shutdown_event.wait()
                logger.info("Shutdown signal received, stopping HTTP server...")

                # Signal uvicorn to stop gracefully
                server.should_exit = True

                # Graceful drain logic: wait for in-flight requests to finish before forcing exit
                # Configurable timeouts via environment variables
                drain_timeout = float(os.getenv("MCP_HTTP_DRAIN_TIMEOUT_SECONDS", "10.0"))
                max_shutdown_timeout = float(
                    os.getenv("MCP_HTTP_SHUTDOWN_TIMEOUT_SECONDS", "30.0")
                )

                start_ts = time.monotonic()

                # 1) Prioritize draining non-streaming requests quickly
                # NOTE(review): relies on HTTPTransportHandler.has_inflight_non_streaming()
                # tracking request counts — confirm against the transport module.
                drained_non_stream = False
                try:
                    while time.monotonic() - start_ts < drain_timeout:
                        if not http_handler.has_inflight_non_streaming():
                            drained_non_stream = True
                            logger.info("Non-streaming requests drained; continuing shutdown")
                            break
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during drain phase")
                    return
                except Exception:
                    # On any error during drain check, fall through to timeout-based force
                    pass

                if not drained_non_stream:
                    logger.warning(
                        f"Non-streaming requests still in flight after {drain_timeout}s; proceeding with shutdown"
                    )

                # 2) Allow additional time (up to max_shutdown_timeout total) for all requests to complete
                total_deadline = start_ts + max_shutdown_timeout
                try:
                    while time.monotonic() < total_deadline:
                        counts = http_handler.get_inflight_request_counts()
                        if counts.get("total", 0) == 0:
                            logger.info("All in-flight requests drained; completing shutdown without force")
                            break
                        await asyncio.sleep(0.2)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during final drain phase")
                    return
                except Exception:
                    pass

                # 3) If still not finished after the max timeout, force the server to exit
                # (hasattr guard keeps this compatible with uvicorn versions
                # lacking force_exit)
                if hasattr(server, "force_exit"):
                    if time.monotonic() >= total_deadline:
                        logger.warning(
                            f"Forcing server exit after {max_shutdown_timeout}s shutdown timeout"
                        )
                        server.force_exit = True
                    else:
                        logger.debug("Server drained gracefully; force_exit not required")
            except asyncio.CancelledError:
                logger.debug("Shutdown monitor task cancelled")
                return

        # Start shutdown monitor task
        monitor_task = asyncio.create_task(shutdown_monitor())

        try:
            # Run the server until shutdown
            await server.serve()
        except asyncio.CancelledError:
            logger.info("Server shutdown initiated")
        except Exception as e:
            # Errors during a requested shutdown are expected noise, not failures.
            if not shutdown_event.is_set():
                logger.error(f"Server error: {e}", exc_info=True)
            else:
                logger.info(f"Server stopped during shutdown: {e}")
        finally:
            # Clean up the monitor task gracefully
            if monitor_task and not monitor_task.done():
                logger.debug("Cleaning up shutdown monitor task")
                monitor_task.cancel()
                try:
                    await asyncio.wait_for(monitor_task, timeout=2.0)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor task cancelled successfully")
                except asyncio.TimeoutError:
                    logger.warning("Shutdown monitor task cleanup timed out")
                except Exception as e:
                    logger.debug(f"Shutdown monitor cleanup completed with: {e}")

    except Exception as e:
        # Re-raise only when the failure wasn't part of a requested shutdown.
        if not shutdown_event.is_set():
            logger.error(f"Error in HTTP server: {e}", exc_info=True)
            raise
    finally:
        # Clean up search engine
        if search_engine:
            try:
                await search_engine.cleanup()
                logger.info("Search engine cleanup completed")
            except Exception as e:
                logger.error(f"Error during search engine cleanup: {e}", exc_info=True)

226 

227 

228async def handle_stdio(config: Config, log_level: str): 

229 """Handle stdio communication with Cursor.""" 

230 logger = LoggingConfig.get_logger(__name__) 

231 

232 try: 

233 # Check if console logging is disabled 

234 disable_console_logging = ( 

235 os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true" 

236 ) 

237 

238 if not disable_console_logging: 

239 logger.info("Setting up stdio handler...") 

240 

241 # Initialize components 

242 search_engine = SearchEngine() 

243 query_processor = QueryProcessor(config.openai) 

244 mcp_handler = MCPHandler(search_engine, query_processor) 

245 

246 # Initialize search engine 

247 try: 

248 await search_engine.initialize(config.qdrant, config.openai, config.search) 

249 if not disable_console_logging: 

250 logger.info("Search engine initialized successfully") 

251 except Exception as e: 

252 logger.error("Failed to initialize search engine", exc_info=True) 

253 raise RuntimeError("Failed to initialize search engine") from e 

254 

255 reader = await read_stdin() 

256 if not disable_console_logging: 

257 logger.info("Server ready to handle requests") 

258 

259 while True: 

260 try: 

261 # Read a line from stdin 

262 if not disable_console_logging: 

263 logger.debug("Waiting for input...") 

264 try: 

265 line = await reader.readline() 

266 if not line: 

267 if not disable_console_logging: 

268 logger.warning("No input received, breaking") 

269 break 

270 except asyncio.CancelledError: 

271 if not disable_console_logging: 

272 logger.info("Read operation cancelled during shutdown") 

273 break 

274 

275 # Log the raw input 

276 raw_input = line.decode().strip() 

277 if not disable_console_logging: 

278 logger.debug("Received raw input", raw_input=raw_input) 

279 

280 # Parse the request 

281 try: 

282 request = json.loads(raw_input) 

283 if not disable_console_logging: 

284 logger.debug("Parsed request", request=request) 

285 except json.JSONDecodeError as e: 

286 if not disable_console_logging: 

287 logger.error("Invalid JSON received", error=str(e)) 

288 # Send error response for invalid JSON 

289 response = { 

290 "jsonrpc": "2.0", 

291 "id": None, 

292 "error": { 

293 "code": -32700, 

294 "message": "Parse error", 

295 "data": f"Invalid JSON received: {str(e)}", 

296 }, 

297 } 

298 sys.stdout.write(json.dumps(response) + "\n") 

299 sys.stdout.flush() 

300 continue 

301 

302 # Validate request format 

303 if not isinstance(request, dict): 

304 if not disable_console_logging: 

305 logger.error("Request must be a JSON object") 

306 response = { 

307 "jsonrpc": "2.0", 

308 "id": None, 

309 "error": { 

310 "code": -32600, 

311 "message": "Invalid Request", 

312 "data": "Request must be a JSON object", 

313 }, 

314 } 

315 sys.stdout.write(json.dumps(response) + "\n") 

316 sys.stdout.flush() 

317 continue 

318 

319 if "jsonrpc" not in request or request["jsonrpc"] != "2.0": 

320 if not disable_console_logging: 

321 logger.error("Invalid JSON-RPC version") 

322 response = { 

323 "jsonrpc": "2.0", 

324 "id": request.get("id"), 

325 "error": { 

326 "code": -32600, 

327 "message": "Invalid Request", 

328 "data": "Invalid JSON-RPC version", 

329 }, 

330 } 

331 sys.stdout.write(json.dumps(response) + "\n") 

332 sys.stdout.flush() 

333 continue 

334 

335 # Process the request 

336 try: 

337 response = await mcp_handler.handle_request(request) 

338 if not disable_console_logging: 

339 logger.debug("Sending response", response=response) 

340 # Only write to stdout if response is not empty (not a notification) 

341 if response: 

342 sys.stdout.write(json.dumps(response) + "\n") 

343 sys.stdout.flush() 

344 except Exception as e: 

345 if not disable_console_logging: 

346 logger.error("Error processing request", exc_info=True) 

347 response = { 

348 "jsonrpc": "2.0", 

349 "id": request.get("id"), 

350 "error": { 

351 "code": -32603, 

352 "message": "Internal error", 

353 "data": str(e), 

354 }, 

355 } 

356 sys.stdout.write(json.dumps(response) + "\n") 

357 sys.stdout.flush() 

358 

359 except asyncio.CancelledError: 

360 if not disable_console_logging: 

361 logger.info("Request handling cancelled during shutdown") 

362 break 

363 except Exception: 

364 if not disable_console_logging: 

365 logger.error("Error handling request", exc_info=True) 

366 continue 

367 

368 # Cleanup 

369 await search_engine.cleanup() 

370 

371 except Exception: 

372 if not disable_console_logging: 

373 logger.error("Error in stdio handler", exc_info=True) 

374 raise 

375 

376 

@click.command(name="mcp-qdrant-loader")
@option(
    "--log-level",
    type=Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Set the logging level.",
)
@option(
    "--config",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to configuration file (currently not implemented).",
)
@option(
    "--transport",
    type=Choice(["stdio", "http"], case_sensitive=False),
    default="stdio",
    help="Transport protocol to use (stdio for JSON-RPC over stdin/stdout, http for HTTP with SSE)",
)
@option(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to bind HTTP server to (only used with --transport http)",
)
@option(
    "--port",
    type=int,
    default=8080,
    help="Port to bind HTTP server to (only used with --transport http)",
)
@option(
    "--env",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to .env file to load environment variables from",
)
@click.version_option(
    version=get_version(),
    message="QDrant Loader MCP Server v%(version)s",
)
def cli(
    log_level: str = "INFO",
    # NOTE: --config is accepted but not used below (help text says
    # "currently not implemented").
    config: Path | None = None,
    transport: str = "stdio",
    host: str = "127.0.0.1",
    port: int = 8080,
    env: Path | None = None,
) -> None:
    """QDrant Loader MCP Server.

    A Model Context Protocol (MCP) server that provides RAG capabilities
    to Cursor and other LLM applications using Qdrant vector database.

    The server supports both stdio (JSON-RPC) and HTTP (with SSE) transports
    for maximum compatibility with different MCP clients.

    Environment Variables:
        QDRANT_URL: URL of your QDrant instance (required)
        QDRANT_API_KEY: API key for QDrant authentication
        QDRANT_COLLECTION_NAME: Name of the collection to use (default: "documents")
        OPENAI_API_KEY: OpenAI API key for embeddings (required)
        MCP_DISABLE_CONSOLE_LOGGING: Set to "true" to disable console logging

    Examples:
        # Start with stdio transport (default, for Cursor/Claude Desktop)
        mcp-qdrant-loader

        # Start with HTTP transport (for web clients)
        mcp-qdrant-loader --transport http --port 8080

        # Start with environment variables from .env file
        mcp-qdrant-loader --transport http --env /path/to/.env

        # Start with debug logging
        mcp-qdrant-loader --log-level DEBUG --transport http

        # Show help
        mcp-qdrant-loader --help

        # Show version
        mcp-qdrant-loader --version
    """
    loop = None
    try:
        # Load environment variables from .env file if specified
        if env:
            load_dotenv(env)
            # Route message through logger (stderr), not stdout, to avoid polluting stdio transport
            LoggingConfig.get_logger(__name__).info(
                "Loaded environment variables", env=str(env)
            )

        # Setup logging (force-disable console logging in stdio transport)
        _setup_logging(log_level, transport)

        # Initialize configuration
        config_obj = Config()

        # Create and set the event loop explicitly (rather than relying on
        # asyncio.run) so signal handlers and the finally-block cleanup can
        # reference the same loop object.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create shutdown event for coordinating graceful shutdown
        shutdown_event = asyncio.Event()
        shutdown_task = None

        # Set up signal handlers with shutdown event
        def signal_handler():
            # Schedule shutdown on the explicit loop for clarity and correctness.
            # The None-guard makes repeated signals idempotent: only one
            # shutdown task is ever created.
            nonlocal shutdown_task
            if shutdown_task is None:
                shutdown_task = loop.create_task(shutdown(loop, shutdown_event))

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Start the appropriate transport handler
        if transport.lower() == "stdio":
            loop.run_until_complete(handle_stdio(config_obj, log_level))
        elif transport.lower() == "http":
            loop.run_until_complete(
                start_http_server(config_obj, log_level, host, port, shutdown_event)
            )
        else:
            # Should be unreachable: click's Choice already restricts values.
            raise ValueError(f"Unsupported transport: {transport}")
    except Exception:
        logger = LoggingConfig.get_logger(__name__)
        logger.error("Error in main", exc_info=True)
        sys.exit(1)
    finally:
        if loop:
            try:
                # First, wait for the shutdown task if it exists.
                # (The 'in locals()' check guards against failures that occur
                # before shutdown_task is ever assigned.)
                if 'shutdown_task' in locals() and shutdown_task is not None and not shutdown_task.done():
                    try:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug("Waiting for shutdown task to complete...")
                        loop.run_until_complete(asyncio.wait_for(shutdown_task, timeout=5.0))
                        logger.debug("Shutdown task completed successfully")
                    except asyncio.TimeoutError:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.warning("Shutdown task timed out, cancelling...")
                        shutdown_task.cancel()
                        try:
                            loop.run_until_complete(shutdown_task)
                        except asyncio.CancelledError:
                            logger.debug("Shutdown task cancelled successfully")
                    except Exception as e:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug(f"Shutdown task completed with: {e}")

                # Then cancel any remaining tasks (except completed shutdown task)
                def _cancel_all_pending_tasks():
                    """Cancel tasks safely without circular dependencies."""
                    all_tasks = list(asyncio.all_tasks(loop))
                    if not all_tasks:
                        return

                    # Cancel all tasks except the completed shutdown task
                    cancelled_tasks = []
                    for task in all_tasks:
                        if not task.done() and task is not shutdown_task:
                            task.cancel()
                            cancelled_tasks.append(task)

                    # Don't await gather to avoid recursion - just let them finish on their own
                    # The loop will handle the cleanup when it closes
                    if cancelled_tasks:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.info(f"Cancelled {len(cancelled_tasks)} remaining tasks for cleanup")

                _cancel_all_pending_tasks()
            except Exception:
                logger = LoggingConfig.get_logger(__name__)
                logger.error("Error during final cleanup", exc_info=True)
            finally:
                loop.close()
                logger = LoggingConfig.get_logger(__name__)
                logger.info("Server shutdown complete")

557 

558 

if __name__ == "__main__":
    # Allow direct execution (python cli.py / python -m ...) in addition to
    # the installed console-script entry point.
    cli()