Coverage for src/qdrant_loader_mcp_server/cli.py: 51%
300 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-13 09:20 +0000
1"""CLI module for QDrant Loader MCP Server."""
3import asyncio
4import json
5import logging
6import os
7import signal
8import sys
9import time
10from pathlib import Path
12import click
13from click.decorators import option
14from click.types import Choice
15from click.types import Path as ClickPath
16from dotenv import load_dotenv
18from .config import Config
19from .mcp import MCPHandler
20from .search.engine import SearchEngine
21from .search.processor import QueryProcessor
22from .transport import HTTPTransportHandler
23from .utils import LoggingConfig, get_version
# Suppress asyncio debug messages to reduce noise in logs.
# asyncio emits verbose DEBUG/INFO records (selector timings, task churn)
# that would drown the server's own output, so raise its threshold to WARNING
# once at import time.
logging.getLogger("asyncio").setLevel(logging.WARNING)
29def _setup_logging(log_level: str, transport: str | None = None) -> None:
30 """Set up logging configuration."""
31 try:
32 # Force-disable console logging in stdio mode to avoid polluting stdout
33 if transport and transport.lower() == "stdio":
34 os.environ["MCP_DISABLE_CONSOLE_LOGGING"] = "true"
36 # Check if console logging is disabled via environment variable (after any override)
37 disable_console_logging = (
38 os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
39 )
41 if not disable_console_logging:
42 # Console format goes to stderr via our logging config
43 LoggingConfig.setup(level=log_level.upper(), format="console")
44 else:
45 LoggingConfig.setup(level=log_level.upper(), format="json")
46 except Exception as e:
47 print(f"Failed to setup logging: {e}", file=sys.stderr)
async def read_stdin():
    """Attach sys.stdin to the running event loop and return a StreamReader.

    Returns:
        asyncio.StreamReader: Reader yielding lines written to this
        process's standard input.
    """
    stream_reader = asyncio.StreamReader()
    reader_protocol = asyncio.StreamReaderProtocol(stream_reader)
    running_loop = asyncio.get_running_loop()
    # connect_read_pipe wires the raw stdin file object into the loop so the
    # reader receives data asynchronously.
    await running_loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
    return stream_reader
async def shutdown(
    loop: asyncio.AbstractEventLoop, shutdown_event: asyncio.Event | None = None
):
    """Handle graceful shutdown by signalling the shutdown event.

    Args:
        loop: The running event loop. Currently unused; kept for interface
            compatibility with existing callers (signal handlers pass it in).
        shutdown_event: Optional event used to tell the server/monitor tasks
            to begin draining and cleanup. Only signalled here — the actual
            teardown is delegated to those tasks.
    """
    logger = LoggingConfig.get_logger(__name__)
    logger.info("Shutting down...")

    # Only signal shutdown; let server/monitor handle draining and cleanup.
    if shutdown_event:
        shutdown_event.set()

    # Yield control so that other tasks (e.g., shutdown monitor, server) can
    # react to the event before this coroutine finishes.
    try:
        await asyncio.sleep(0)
    except asyncio.CancelledError:
        # If the shutdown task itself is cancelled, just exit quietly.
        return

    logger.info("Shutdown signal dispatched")
async def start_http_server(
    config: Config, log_level: str, host: str, port: int, shutdown_event: asyncio.Event
):
    """Start MCP server with HTTP transport.

    Initializes the search engine and MCP handler, then runs a uvicorn server
    hosting the HTTP transport's FastAPI app until ``shutdown_event`` is set.
    Shutdown is two-phase: first drain in-flight requests (non-streaming
    prioritized), then force-exit uvicorn if a hard deadline is exceeded.

    Args:
        config: Application configuration (qdrant/openai/search sections).
        log_level: Logging level name; also controls uvicorn access logging.
        host: Interface to bind the HTTP server to.
        port: TCP port to bind the HTTP server to.
        shutdown_event: Event signalled (e.g. from a signal handler) to begin
            graceful shutdown.

    Raises:
        RuntimeError: If the search engine fails to initialize.
        Exception: Re-raises unexpected server errors when not shutting down.
    """
    logger = LoggingConfig.get_logger(__name__)
    search_engine = None  # tracked so the finally block can clean it up

    try:
        logger.info(f"Starting HTTP server on {host}:{port}")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        # Create HTTP transport handler (wraps the MCP handler in a FastAPI app)
        http_handler = HTTPTransportHandler(mcp_handler, host=host, port=port)

        # Start the FastAPI server using uvicorn.
        # Imported lazily so stdio-only deployments don't need uvicorn loaded.
        import uvicorn

        uvicorn_config = uvicorn.Config(
            app=http_handler.app,
            host=host,
            port=port,
            log_level=log_level.lower(),
            # Access logs only in DEBUG mode to keep output quiet by default.
            access_log=log_level.upper() == "DEBUG",
        )

        server = uvicorn.Server(uvicorn_config)
        logger.info(f"HTTP MCP server ready at http://{host}:{port}/mcp")

        # Create a task to monitor shutdown event.
        # Closes over `server`, `http_handler` and `shutdown_event`.
        async def shutdown_monitor():
            try:
                await shutdown_event.wait()
                logger.info("Shutdown signal received, stopping HTTP server...")

                # Signal uvicorn to stop gracefully
                server.should_exit = True

                # Graceful drain logic: wait for in-flight requests to finish before forcing exit
                # Configurable timeouts via environment variables
                drain_timeout = float(os.getenv("MCP_HTTP_DRAIN_TIMEOUT_SECONDS", "10.0"))
                max_shutdown_timeout = float(
                    os.getenv("MCP_HTTP_SHUTDOWN_TIMEOUT_SECONDS", "30.0")
                )

                start_ts = time.monotonic()

                # 1) Prioritize draining non-streaming requests quickly
                #    (polls the transport handler every 100ms up to drain_timeout)
                drained_non_stream = False
                try:
                    while time.monotonic() - start_ts < drain_timeout:
                        if not http_handler.has_inflight_non_streaming():
                            drained_non_stream = True
                            logger.info("Non-streaming requests drained; continuing shutdown")
                            break
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during drain phase")
                    return
                except Exception:
                    # On any error during drain check, fall through to timeout-based force
                    pass

                if not drained_non_stream:
                    logger.warning(
                        f"Non-streaming requests still in flight after {drain_timeout}s; proceeding with shutdown"
                    )

                # 2) Allow additional time (up to max_shutdown_timeout total) for all requests to complete
                total_deadline = start_ts + max_shutdown_timeout
                try:
                    while time.monotonic() < total_deadline:
                        counts = http_handler.get_inflight_request_counts()
                        if counts.get("total", 0) == 0:
                            logger.info("All in-flight requests drained; completing shutdown without force")
                            break
                        await asyncio.sleep(0.2)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during final drain phase")
                    return
                except Exception:
                    pass

                # 3) If still not finished after the max timeout, force the server to exit.
                #    hasattr guard keeps this working if a uvicorn version lacks force_exit.
                if hasattr(server, "force_exit"):
                    if time.monotonic() >= total_deadline:
                        logger.warning(
                            f"Forcing server exit after {max_shutdown_timeout}s shutdown timeout"
                        )
                        server.force_exit = True
                    else:
                        logger.debug("Server drained gracefully; force_exit not required")
            except asyncio.CancelledError:
                logger.debug("Shutdown monitor task cancelled")
                return

        # Start shutdown monitor task
        monitor_task = asyncio.create_task(shutdown_monitor())

        try:
            # Run the server until shutdown
            await server.serve()
        except asyncio.CancelledError:
            logger.info("Server shutdown initiated")
        except Exception as e:
            # Errors during an expected shutdown are informational, not failures.
            if not shutdown_event.is_set():
                logger.error(f"Server error: {e}", exc_info=True)
            else:
                logger.info(f"Server stopped during shutdown: {e}")
        finally:
            # Clean up the monitor task gracefully (bounded wait, then give up)
            if monitor_task and not monitor_task.done():
                logger.debug("Cleaning up shutdown monitor task")
                monitor_task.cancel()
                try:
                    await asyncio.wait_for(monitor_task, timeout=2.0)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor task cancelled successfully")
                except asyncio.TimeoutError:
                    logger.warning("Shutdown monitor task cleanup timed out")
                except Exception as e:
                    logger.debug(f"Shutdown monitor cleanup completed with: {e}")

    except Exception as e:
        # Re-raise only unexpected failures; suppress noise during shutdown.
        if not shutdown_event.is_set():
            logger.error(f"Error in HTTP server: {e}", exc_info=True)
            raise
    finally:
        # Clean up search engine regardless of how the server exited
        if search_engine:
            try:
                await search_engine.cleanup()
                logger.info("Search engine cleanup completed")
            except Exception as e:
                logger.error(f"Error during search engine cleanup: {e}", exc_info=True)
async def handle_stdio(config: Config, log_level: str):
    """Handle stdio communication with Cursor.

    Runs the JSON-RPC 2.0 read-eval-respond loop over stdin/stdout: each line
    of stdin is parsed as a request, dispatched to the MCP handler, and the
    response (if any) is written as a single JSON line to stdout. All logging
    is suppressed when MCP_DISABLE_CONSOLE_LOGGING=true, since stdout must
    carry only protocol traffic.

    Args:
        config: Application configuration (qdrant/openai/search sections).
        log_level: Logging level name (not used directly here; logging is
            configured by the caller).

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)

    try:
        # Check if console logging is disabled
        disable_console_logging = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )

        if not disable_console_logging:
            logger.info("Setting up stdio handler...")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            if not disable_console_logging:
                logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        reader = await read_stdin()
        if not disable_console_logging:
            logger.info("Server ready to handle requests")

        # Main request loop: one JSON-RPC message per line.
        while True:
            try:
                # Read a line from stdin
                if not disable_console_logging:
                    logger.debug("Waiting for input...")
                try:
                    line = await reader.readline()
                    if not line:
                        # EOF: the client closed stdin, so the server exits.
                        if not disable_console_logging:
                            logger.warning("No input received, breaking")
                        break
                except asyncio.CancelledError:
                    if not disable_console_logging:
                        logger.info("Read operation cancelled during shutdown")
                    break

                # Log the raw input
                raw_input = line.decode().strip()
                if not disable_console_logging:
                    logger.debug("Received raw input", raw_input=raw_input)

                # Parse the request
                try:
                    request = json.loads(raw_input)
                    if not disable_console_logging:
                        logger.debug("Parsed request", request=request)
                except json.JSONDecodeError as e:
                    if not disable_console_logging:
                        logger.error("Invalid JSON received", error=str(e))
                    # Send error response for invalid JSON (JSON-RPC -32700 Parse error)
                    response = {
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32700,
                            "message": "Parse error",
                            "data": f"Invalid JSON received: {str(e)}",
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\n")
                    sys.stdout.flush()
                    continue

                # Validate request format (JSON-RPC -32600 Invalid Request)
                if not isinstance(request, dict):
                    if not disable_console_logging:
                        logger.error("Request must be a JSON object")
                    response = {
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {
                            "code": -32600,
                            "message": "Invalid Request",
                            "data": "Request must be a JSON object",
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\n")
                    sys.stdout.flush()
                    continue

                if "jsonrpc" not in request or request["jsonrpc"] != "2.0":
                    if not disable_console_logging:
                        logger.error("Invalid JSON-RPC version")
                    response = {
                        "jsonrpc": "2.0",
                        "id": request.get("id"),
                        "error": {
                            "code": -32600,
                            "message": "Invalid Request",
                            "data": "Invalid JSON-RPC version",
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\n")
                    sys.stdout.flush()
                    continue

                # Process the request
                try:
                    response = await mcp_handler.handle_request(request)
                    if not disable_console_logging:
                        logger.debug("Sending response", response=response)
                    # Only write to stdout if response is not empty (not a notification)
                    if response:
                        sys.stdout.write(json.dumps(response) + "\n")
                        sys.stdout.flush()
                except Exception as e:
                    if not disable_console_logging:
                        logger.error("Error processing request", exc_info=True)
                    # JSON-RPC -32603 Internal error
                    response = {
                        "jsonrpc": "2.0",
                        "id": request.get("id"),
                        "error": {
                            "code": -32603,
                            "message": "Internal error",
                            "data": str(e),
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\n")
                    sys.stdout.flush()

            except asyncio.CancelledError:
                if not disable_console_logging:
                    logger.info("Request handling cancelled during shutdown")
                break
            except Exception:
                # Unexpected per-request failure: log and keep serving.
                if not disable_console_logging:
                    logger.error("Error handling request", exc_info=True)
                continue

        # Cleanup
        await search_engine.cleanup()

    except Exception:
        # NOTE(review): disable_console_logging is unbound if the very first
        # statement in the try raised — unlikely, but worth confirming.
        if not disable_console_logging:
            logger.error("Error in stdio handler", exc_info=True)
        raise
@click.command(name="mcp-qdrant-loader")
@option(
    "--log-level",
    type=Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Set the logging level.",
)
@option(
    "--config",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to configuration file (currently not implemented).",
)
@option(
    "--transport",
    type=Choice(["stdio", "http"], case_sensitive=False),
    default="stdio",
    help="Transport protocol to use (stdio for JSON-RPC over stdin/stdout, http for HTTP with SSE)",
)
@option(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to bind HTTP server to (only used with --transport http)",
)
@option(
    "--port",
    type=int,
    default=8080,
    help="Port to bind HTTP server to (only used with --transport http)",
)
@option(
    "--env",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to .env file to load environment variables from",
)
@click.version_option(
    version=get_version(),
    message="QDrant Loader MCP Server v%(version)s",
)
def cli(
    log_level: str = "INFO",
    config: Path | None = None,
    transport: str = "stdio",
    host: str = "127.0.0.1",
    port: int = 8080,
    env: Path | None = None,
) -> None:
    """QDrant Loader MCP Server.

    A Model Context Protocol (MCP) server that provides RAG capabilities
    to Cursor and other LLM applications using Qdrant vector database.

    The server supports both stdio (JSON-RPC) and HTTP (with SSE) transports
    for maximum compatibility with different MCP clients.

    Environment Variables:
        QDRANT_URL: URL of your QDrant instance (required)
        QDRANT_API_KEY: API key for QDrant authentication
        QDRANT_COLLECTION_NAME: Name of the collection to use (default: "documents")
        OPENAI_API_KEY: OpenAI API key for embeddings (required)
        MCP_DISABLE_CONSOLE_LOGGING: Set to "true" to disable console logging

    Examples:
        # Start with stdio transport (default, for Cursor/Claude Desktop)
        mcp-qdrant-loader

        # Start with HTTP transport (for web clients)
        mcp-qdrant-loader --transport http --port 8080

        # Start with environment variables from .env file
        mcp-qdrant-loader --transport http --env /path/to/.env

        # Start with debug logging
        mcp-qdrant-loader --log-level DEBUG --transport http

        # Show help
        mcp-qdrant-loader --help

        # Show version
        mcp-qdrant-loader --version
    """
    loop = None  # created below; checked in finally for cleanup
    try:
        # Load environment variables from .env file if specified
        if env:
            load_dotenv(env)
            # Route message through logger (stderr), not stdout, to avoid polluting stdio transport
            LoggingConfig.get_logger(__name__).info(
                "Loaded environment variables", env=str(env)
            )

        # Setup logging (force-disable console logging in stdio transport)
        _setup_logging(log_level, transport)

        # Initialize configuration
        config_obj = Config()

        # Create and set the event loop explicitly so signal handlers and the
        # finally-block cleanup below can reference it directly.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create shutdown event for coordinating graceful shutdown
        shutdown_event = asyncio.Event()
        shutdown_task = None

        # Set up signal handlers with shutdown event
        def signal_handler():
            # Schedule shutdown on the explicit loop for clarity and correctness.
            # The None check makes repeated signals idempotent (one shutdown task).
            nonlocal shutdown_task
            if shutdown_task is None:
                shutdown_task = loop.create_task(shutdown(loop, shutdown_event))

        # NOTE(review): add_signal_handler is not available on Windows event
        # loops — confirm target platforms.
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Start the appropriate transport handler (blocks until shutdown)
        if transport.lower() == "stdio":
            loop.run_until_complete(handle_stdio(config_obj, log_level))
        elif transport.lower() == "http":
            loop.run_until_complete(
                start_http_server(config_obj, log_level, host, port, shutdown_event)
            )
        else:
            raise ValueError(f"Unsupported transport: {transport}")
    except Exception:
        logger = LoggingConfig.get_logger(__name__)
        logger.error("Error in main", exc_info=True)
        sys.exit(1)
    finally:
        if loop:
            try:
                # First, wait for the shutdown task if it exists.
                # locals() check guards against an exception before shutdown_task
                # was ever bound.
                if 'shutdown_task' in locals() and shutdown_task is not None and not shutdown_task.done():
                    try:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug("Waiting for shutdown task to complete...")
                        loop.run_until_complete(asyncio.wait_for(shutdown_task, timeout=5.0))
                        logger.debug("Shutdown task completed successfully")
                    except asyncio.TimeoutError:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.warning("Shutdown task timed out, cancelling...")
                        shutdown_task.cancel()
                        try:
                            loop.run_until_complete(shutdown_task)
                        except asyncio.CancelledError:
                            logger.debug("Shutdown task cancelled successfully")
                    except Exception as e:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug(f"Shutdown task completed with: {e}")

                # Then cancel any remaining tasks (except completed shutdown task)
                def _cancel_all_pending_tasks():
                    """Cancel tasks safely without circular dependencies."""
                    all_tasks = list(asyncio.all_tasks(loop))
                    if not all_tasks:
                        return

                    # Cancel all tasks except the completed shutdown task
                    cancelled_tasks = []
                    for task in all_tasks:
                        if not task.done() and task is not shutdown_task:
                            task.cancel()
                            cancelled_tasks.append(task)

                    # Don't await gather to avoid recursion - just let them finish on their own
                    # The loop will handle the cleanup when it closes
                    if cancelled_tasks:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.info(f"Cancelled {len(cancelled_tasks)} remaining tasks for cleanup")

                _cancel_all_pending_tasks()
            except Exception:
                logger = LoggingConfig.get_logger(__name__)
                logger.error("Error during final cleanup", exc_info=True)
            finally:
                # Always close the loop, even if cleanup above failed.
                loop.close()
                logger = LoggingConfig.get_logger(__name__)
                logger.info("Server shutdown complete")
# Script entry point: delegate to the Click command when run directly.
if __name__ == "__main__":
    cli()