Coverage for src/qdrant_loader_mcp_server/cli.py: 51%
322 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-11 07:22 +0000
1"""CLI module for QDrant Loader MCP Server."""
3import asyncio
4import json
5import logging
6import os
7import signal
8import sys
9import time
10from pathlib import Path
12import click
13from click.decorators import option
14from click.types import Choice
15from click.types import Path as ClickPath
16from dotenv import load_dotenv
18from .config import Config
19from .config_loader import load_config, redact_effective_config
20from .mcp import MCPHandler
21from .search.engine import SearchEngine
22from .search.processor import QueryProcessor
23from .transport import HTTPTransportHandler
24from .utils import LoggingConfig, get_version
26# Suppress asyncio debug messages to reduce noise in logs.
27logging.getLogger("asyncio").setLevel(logging.WARNING)
def _setup_logging(log_level: str, transport: str | None = None) -> None:
    """Configure application logging for the selected transport.

    In stdio mode, console output would corrupt the JSON-RPC stream on
    stdout, so console logging is force-disabled via the environment
    before the logging backend is (re)configured.
    """
    try:
        # stdio transport: stdout must carry only JSON-RPC payloads.
        if transport and transport.lower() == "stdio":
            os.environ["MCP_DISABLE_CONSOLE_LOGGING"] = "true"

        # Resolve the effective console setting after any override above.
        console_disabled = (
            os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
        )
        log_format = "json" if console_disabled else "console"

        # Drop any handlers installed before CLI configuration ran (e.g. by
        # implicit setup during module import) to avoid duplicate records.
        root = logging.getLogger()
        for handler in list(root.handlers):
            try:
                root.removeHandler(handler)
            except Exception:
                pass

        level = log_level.upper()
        if getattr(LoggingConfig, "reconfigure", None):  # type: ignore[attr-defined]
            if getattr(LoggingConfig, "_initialized", False):  # type: ignore[attr-defined]
                # Already initialized: only switch the file target (none in
                # stdio; may be provided via the environment).
                LoggingConfig.reconfigure(file=os.getenv("MCP_LOG_FILE"))  # type: ignore[attr-defined]
            else:
                LoggingConfig.setup(level=level, format=log_format)
        else:
            # Older LoggingConfig without reconfigure(): hard-reset handlers
            # so repeated setup calls do not stack them.
            logging.getLogger().handlers = []
            LoggingConfig.setup(level=level, format=log_format)
    except Exception as e:
        print(f"Failed to setup logging: {e}", file=sys.stderr)
async def read_stdin():
    """Attach an asyncio StreamReader to this process's stdin and return it."""
    reader = asyncio.StreamReader()
    reader_protocol = asyncio.StreamReaderProtocol(reader)
    loop = asyncio.get_running_loop()
    await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
    return reader
async def shutdown(
    loop: asyncio.AbstractEventLoop, shutdown_event: asyncio.Event | None = None
):
    """Handle graceful shutdown.

    Signals the shared shutdown event (if provided) and yields control so the
    transport-specific machinery (server / monitor tasks) can react; it does
    not drain requests or clean up resources itself.

    Args:
        loop: The running event loop (kept for signature compatibility).
        shutdown_event: Optional event used to coordinate shutdown across tasks.
    """
    logger = LoggingConfig.get_logger(__name__)
    logger.info("Shutting down...")

    # Only signal shutdown; let server/monitor handle draining and cleanup
    if shutdown_event:
        shutdown_event.set()

    # Yield control so that other tasks (e.g., shutdown monitor, server) can react
    try:
        await asyncio.sleep(0)
    except asyncio.CancelledError:
        # If shutdown task is cancelled, just exit quietly
        return

    logger.info("Shutdown signal dispatched")
async def start_http_server(
    config: Config, log_level: str, host: str, port: int, shutdown_event: asyncio.Event
):
    """Start MCP server with HTTP transport.

    Initializes the search stack, serves the FastAPI app via uvicorn, and runs
    a monitor task that drains in-flight requests before letting the server
    stop once ``shutdown_event`` is set.

    Args:
        config: Fully-loaded server configuration.
        log_level: Logging level name; also controls uvicorn access logging.
        host: Interface to bind the HTTP server to.
        port: TCP port to bind the HTTP server to.
        shutdown_event: Event set by the signal handler to trigger shutdown.

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)
    search_engine = None

    try:
        logger.info(f"Starting HTTP server on {host}:{port}")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        # Create HTTP transport handler
        http_handler = HTTPTransportHandler(mcp_handler, host=host, port=port)

        # Start the FastAPI server using uvicorn (imported lazily so stdio-only
        # deployments do not pay for it at module import time)
        import uvicorn

        uvicorn_config = uvicorn.Config(
            app=http_handler.app,
            host=host,
            port=port,
            log_level=log_level.lower(),
            access_log=log_level.upper() == "DEBUG",
        )

        server = uvicorn.Server(uvicorn_config)
        logger.info(f"HTTP MCP server ready at http://{host}:{port}/mcp")

        # Create a task to monitor shutdown event
        async def shutdown_monitor():
            """Wait for the shutdown signal, then drain requests and stop uvicorn."""
            try:
                await shutdown_event.wait()
                logger.info("Shutdown signal received, stopping HTTP server...")

                # Signal uvicorn to stop gracefully
                server.should_exit = True

                # Graceful drain logic: wait for in-flight requests to finish before forcing exit
                # Configurable timeouts via environment variables
                drain_timeout = float(
                    os.getenv("MCP_HTTP_DRAIN_TIMEOUT_SECONDS", "10.0")
                )
                max_shutdown_timeout = float(
                    os.getenv("MCP_HTTP_SHUTDOWN_TIMEOUT_SECONDS", "30.0")
                )

                start_ts = time.monotonic()

                # 1) Prioritize draining non-streaming requests quickly
                drained_non_stream = False
                try:
                    while time.monotonic() - start_ts < drain_timeout:
                        if not http_handler.has_inflight_non_streaming():
                            drained_non_stream = True
                            logger.info(
                                "Non-streaming requests drained; continuing shutdown"
                            )
                            break
                        await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during drain phase")
                    return
                except Exception:
                    # On any error during drain check, fall through to timeout-based force
                    pass

                if not drained_non_stream:
                    logger.warning(
                        f"Non-streaming requests still in flight after {drain_timeout}s; proceeding with shutdown"
                    )

                # 2) Allow additional time (up to max_shutdown_timeout total) for all requests to complete
                total_deadline = start_ts + max_shutdown_timeout
                try:
                    while time.monotonic() < total_deadline:
                        counts = http_handler.get_inflight_request_counts()
                        if counts.get("total", 0) == 0:
                            logger.info(
                                "All in-flight requests drained; completing shutdown without force"
                            )
                            break
                        await asyncio.sleep(0.2)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor cancelled during final drain phase")
                    return
                except Exception:
                    pass

                # 3) If still not finished after the max timeout, force the server to exit
                if hasattr(server, "force_exit"):
                    if time.monotonic() >= total_deadline:
                        logger.warning(
                            f"Forcing server exit after {max_shutdown_timeout}s shutdown timeout"
                        )
                        server.force_exit = True
                    else:
                        logger.debug(
                            "Server drained gracefully; force_exit not required"
                        )
            except asyncio.CancelledError:
                logger.debug("Shutdown monitor task cancelled")
                return

        # Start shutdown monitor task
        monitor_task = asyncio.create_task(shutdown_monitor())

        try:
            # Run the server until shutdown
            await server.serve()
        except asyncio.CancelledError:
            logger.info("Server shutdown initiated")
        except Exception as e:
            if not shutdown_event.is_set():
                logger.error(f"Server error: {e}", exc_info=True)
            else:
                logger.info(f"Server stopped during shutdown: {e}")
        finally:
            # Clean up the monitor task gracefully
            if monitor_task and not monitor_task.done():
                logger.debug("Cleaning up shutdown monitor task")
                monitor_task.cancel()
                try:
                    await asyncio.wait_for(monitor_task, timeout=2.0)
                except asyncio.CancelledError:
                    logger.debug("Shutdown monitor task cancelled successfully")
                except (TimeoutError, asyncio.TimeoutError):
                    # asyncio.TimeoutError is only an alias of the builtin
                    # TimeoutError on Python >= 3.11; catch both so the timeout
                    # path also fires on Python 3.10.
                    logger.warning("Shutdown monitor task cleanup timed out")
                except Exception as e:
                    logger.debug(f"Shutdown monitor cleanup completed with: {e}")

    except Exception as e:
        if not shutdown_event.is_set():
            logger.error(f"Error in HTTP server: {e}", exc_info=True)
        raise
    finally:
        # Clean up search engine
        if search_engine:
            try:
                await search_engine.cleanup()
                logger.info("Search engine cleanup completed")
            except Exception as e:
                logger.error(f"Error during search engine cleanup: {e}", exc_info=True)
async def handle_stdio(config: Config, log_level: str):
    """Handle stdio communication with Cursor.

    Reads newline-delimited JSON-RPC 2.0 requests from stdin, dispatches them
    to the MCP handler, and writes responses to stdout. Console logging is
    suppressed when MCP_DISABLE_CONSOLE_LOGGING is "true" so that stdout
    carries only protocol payloads.

    Args:
        config: Fully-loaded server configuration.
        log_level: Logging level name (configured earlier by the CLI).

    Raises:
        RuntimeError: If the search engine fails to initialize.
    """
    logger = LoggingConfig.get_logger(__name__)

    # Check if console logging is disabled. Resolved before the main try so
    # the except/finally handlers below can always reference it safely
    # (previously an early exception could hit the handler before assignment).
    disable_console_logging = (
        os.getenv("MCP_DISABLE_CONSOLE_LOGGING", "").lower() == "true"
    )

    def _write_response(response: dict) -> None:
        """Serialize a JSON-RPC response to stdout and flush immediately."""
        sys.stdout.write(json.dumps(response) + "\n")
        sys.stdout.flush()

    search_engine = None
    try:
        if not disable_console_logging:
            logger.info("Setting up stdio handler...")

        # Initialize components
        search_engine = SearchEngine()
        query_processor = QueryProcessor(config.openai)
        mcp_handler = MCPHandler(search_engine, query_processor)

        # Initialize search engine
        try:
            await search_engine.initialize(config.qdrant, config.openai, config.search)
            if not disable_console_logging:
                logger.info("Search engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize search engine", exc_info=True)
            raise RuntimeError("Failed to initialize search engine") from e

        reader = await read_stdin()
        if not disable_console_logging:
            logger.info("Server ready to handle requests")

        while True:
            try:
                # Read a line from stdin
                if not disable_console_logging:
                    logger.debug("Waiting for input...")
                try:
                    line = await reader.readline()
                    if not line:
                        # EOF: the client closed our stdin, so stop serving.
                        if not disable_console_logging:
                            logger.warning("No input received, breaking")
                        break
                except asyncio.CancelledError:
                    if not disable_console_logging:
                        logger.info("Read operation cancelled during shutdown")
                    break

                # Log the raw input
                raw_input = line.decode().strip()
                if not disable_console_logging:
                    logger.debug("Received raw input", raw_input=raw_input)

                # Parse the request
                try:
                    request = json.loads(raw_input)
                    if not disable_console_logging:
                        logger.debug("Parsed request", request=request)
                except json.JSONDecodeError as e:
                    if not disable_console_logging:
                        logger.error("Invalid JSON received", error=str(e))
                    # Send error response for invalid JSON
                    _write_response(
                        {
                            "jsonrpc": "2.0",
                            "id": None,
                            "error": {
                                "code": -32700,
                                "message": "Parse error",
                                "data": f"Invalid JSON received: {str(e)}",
                            },
                        }
                    )
                    continue

                # Validate request format
                if not isinstance(request, dict):
                    if not disable_console_logging:
                        logger.error("Request must be a JSON object")
                    _write_response(
                        {
                            "jsonrpc": "2.0",
                            "id": None,
                            "error": {
                                "code": -32600,
                                "message": "Invalid Request",
                                "data": "Request must be a JSON object",
                            },
                        }
                    )
                    continue

                if "jsonrpc" not in request or request["jsonrpc"] != "2.0":
                    if not disable_console_logging:
                        logger.error("Invalid JSON-RPC version")
                    _write_response(
                        {
                            "jsonrpc": "2.0",
                            "id": request.get("id"),
                            "error": {
                                "code": -32600,
                                "message": "Invalid Request",
                                "data": "Invalid JSON-RPC version",
                            },
                        }
                    )
                    continue

                # Process the request
                try:
                    response = await mcp_handler.handle_request(request)
                    if not disable_console_logging:
                        logger.debug("Sending response", response=response)
                    # Only write to stdout if response is not empty (not a notification)
                    if response:
                        _write_response(response)
                except Exception as e:
                    if not disable_console_logging:
                        logger.error("Error processing request", exc_info=True)
                    _write_response(
                        {
                            "jsonrpc": "2.0",
                            "id": request.get("id"),
                            "error": {
                                "code": -32603,
                                "message": "Internal error",
                                "data": str(e),
                            },
                        }
                    )

            except asyncio.CancelledError:
                if not disable_console_logging:
                    logger.info("Request handling cancelled during shutdown")
                break
            except Exception:
                if not disable_console_logging:
                    logger.error("Error handling request", exc_info=True)
                continue

    except Exception:
        if not disable_console_logging:
            logger.error("Error in stdio handler", exc_info=True)
        raise
    finally:
        # Cleanup: always release engine resources, even when the loop exits
        # via an exception (previously cleanup only ran on the happy path).
        if search_engine is not None:
            try:
                await search_engine.cleanup()
            except Exception:
                if not disable_console_logging:
                    logger.error("Error during search engine cleanup", exc_info=True)
@click.command(name="mcp-qdrant-loader")
@option(
    "--log-level",
    type=Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Set the logging level.",
)
# Hidden option to print effective config (redacts secrets)
@option(
    "--print-config",
    is_flag=True,
    default=False,
    help="Print the effective configuration (secrets redacted) and exit.",
)
@option(
    "--config",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to configuration file.",
)
@option(
    "--transport",
    type=Choice(["stdio", "http"], case_sensitive=False),
    default="stdio",
    help="Transport protocol to use (stdio for JSON-RPC over stdin/stdout, http for HTTP with SSE)",
)
@option(
    "--host",
    type=str,
    default="127.0.0.1",
    help="Host to bind HTTP server to (only used with --transport http)",
)
@option(
    "--port",
    type=int,
    default=8080,
    help="Port to bind HTTP server to (only used with --transport http)",
)
@option(
    "--env",
    type=ClickPath(exists=True, path_type=Path),
    help="Path to .env file to load environment variables from",
)
@click.version_option(
    version=get_version(),
    message="QDrant Loader MCP Server v%(version)s",
)
def cli(
    log_level: str = "INFO",
    config: Path | None = None,
    transport: str = "stdio",
    host: str = "127.0.0.1",
    port: int = 8080,
    env: Path | None = None,
    print_config: bool = False,
) -> None:
    """QDrant Loader MCP Server.

    A Model Context Protocol (MCP) server that provides RAG capabilities
    to Cursor and other LLM applications using Qdrant vector database.

    The server supports both stdio (JSON-RPC) and HTTP (with SSE) transports
    for maximum compatibility with different MCP clients.

    Environment Variables:
        QDRANT_URL: URL of your QDrant instance (required)
        QDRANT_API_KEY: API key for QDrant authentication
        QDRANT_COLLECTION_NAME: Name of the collection to use (default: "documents")
        OPENAI_API_KEY: OpenAI API key for embeddings (required)
        MCP_DISABLE_CONSOLE_LOGGING: Set to "true" to disable console logging

    Examples:
        # Start with stdio transport (default, for Cursor/Claude Desktop)
        mcp-qdrant-loader

        # Start with HTTP transport (for web clients)
        mcp-qdrant-loader --transport http --port 8080

        # Start with environment variables from .env file
        mcp-qdrant-loader --transport http --env /path/to/.env

        # Start with debug logging
        mcp-qdrant-loader --log-level DEBUG --transport http

        # Show help
        mcp-qdrant-loader --help

        # Show version
        mcp-qdrant-loader --version
    """
    loop = None
    try:
        # Load environment variables from .env file if specified
        if env:
            load_dotenv(env)

        # Setup logging (force-disable console logging in stdio transport)
        _setup_logging(log_level, transport)

        # Log env file load after logging is configured to avoid duplicate handler setup
        if env:
            LoggingConfig.get_logger(__name__).info(
                "Loaded environment variables", env=str(env)
            )

        # If a config file was provided, propagate it via MCP_CONFIG so that
        # any internal callers that resolve config without CLI context can find it.
        if config is not None:
            try:
                os.environ["MCP_CONFIG"] = str(config)
            except Exception:
                # Best-effort; continue without blocking startup
                pass

        # Initialize configuration (file/env precedence); the resolved source
        # path returned by load_config is not needed here.
        config_obj, effective_cfg, _ = load_config(config)

        if print_config:
            redacted = redact_effective_config(effective_cfg)
            click.echo(json.dumps(redacted, indent=2))
            return

        # Create and set the event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Create shutdown event for coordinating graceful shutdown
        shutdown_event = asyncio.Event()
        shutdown_task = None

        # Set up signal handlers with shutdown event
        def signal_handler():
            # Schedule shutdown on the explicit loop for clarity and correctness;
            # guard so repeated signals do not spawn duplicate shutdown tasks.
            nonlocal shutdown_task
            if shutdown_task is None:
                shutdown_task = loop.create_task(shutdown(loop, shutdown_event))

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Start the appropriate transport handler
        if transport.lower() == "stdio":
            loop.run_until_complete(handle_stdio(config_obj, log_level))
        elif transport.lower() == "http":
            loop.run_until_complete(
                start_http_server(config_obj, log_level, host, port, shutdown_event)
            )
        else:
            raise ValueError(f"Unsupported transport: {transport}")
    except Exception:
        logger = LoggingConfig.get_logger(__name__)
        logger.error("Error in main", exc_info=True)
        sys.exit(1)
    finally:
        if loop:
            try:
                # First, wait for the shutdown task if it exists
                if (
                    "shutdown_task" in locals()
                    and shutdown_task is not None
                    and not shutdown_task.done()
                ):
                    try:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug("Waiting for shutdown task to complete...")
                        loop.run_until_complete(
                            asyncio.wait_for(shutdown_task, timeout=5.0)
                        )
                        logger.debug("Shutdown task completed successfully")
                    except (TimeoutError, asyncio.TimeoutError):
                        # asyncio.TimeoutError is only an alias of the builtin
                        # TimeoutError on Python >= 3.11; catch both so the
                        # timeout branch also fires on Python 3.10.
                        logger = LoggingConfig.get_logger(__name__)
                        logger.warning("Shutdown task timed out, cancelling...")
                        shutdown_task.cancel()
                        try:
                            loop.run_until_complete(shutdown_task)
                        except asyncio.CancelledError:
                            logger.debug("Shutdown task cancelled successfully")
                    except Exception as e:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.debug(f"Shutdown task completed with: {e}")

                # Then cancel any remaining tasks (except completed shutdown task)
                def _cancel_all_pending_tasks():
                    """Cancel tasks safely without circular dependencies."""
                    all_tasks = list(asyncio.all_tasks(loop))
                    if not all_tasks:
                        return

                    # Cancel all tasks except the completed shutdown task
                    cancelled_tasks = []
                    for task in all_tasks:
                        if not task.done() and task is not shutdown_task:
                            task.cancel()
                            cancelled_tasks.append(task)

                    # Don't await gather to avoid recursion - just let them finish on their own
                    # The loop will handle the cleanup when it closes
                    if cancelled_tasks:
                        logger = LoggingConfig.get_logger(__name__)
                        logger.info(
                            f"Cancelled {len(cancelled_tasks)} remaining tasks for cleanup"
                        )

                _cancel_all_pending_tasks()
            except Exception:
                logger = LoggingConfig.get_logger(__name__)
                logger.error("Error during final cleanup", exc_info=True)
            finally:
                loop.close()
                logger = LoggingConfig.get_logger(__name__)
                logger.info("Server shutdown complete")
621if __name__ == "__main__":
622 cli()