Coverage for src/qdrant_loader_mcp_server/transport/http_handler.py: 81%
129 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-13 09:20 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-13 09:20 +0000
1"""HTTP Transport Handler for MCP Protocol 2025-06-18."""
import asyncio
import json
import time
import uuid
from typing import Any
from urllib.parse import urlsplit

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from ..utils.logging import LoggingConfig
# Module-level logger, created through the project's LoggingConfig helper.
logger = LoggingConfig.get_logger(__name__)
class HTTPTransportHandler:
    """HTTP Transport Handler for MCP Protocol with SSE streaming support.

    Wraps an MCP handler in a FastAPI application exposing:

    * ``POST /mcp``    - client-to-server JSON-RPC messages
    * ``GET /mcp``     - server-to-client Server-Sent Events (SSE) stream
    * ``OPTIONS /mcp`` - CORS handling for browser clients
    * ``GET /health``  - health check

    Only localhost origins are accepted (DNS rebinding protection), and
    in-flight requests are counted so shutdown code can drain them.
    """

    # Schemes and hostnames accepted in a browser ``Origin`` header. Keep these
    # in sync with the CORS ``allow_origin_regex`` in ``_setup_middleware``.
    _ALLOWED_ORIGIN_SCHEMES = frozenset({"http", "https"})
    _ALLOWED_ORIGIN_HOSTS = frozenset({"localhost", "127.0.0.1"})

    # MCP protocol revisions accepted in the ``mcp-protocol-version`` header.
    _SUPPORTED_PROTOCOL_VERSIONS = ("2025-06-18", "2025-03-26", "2024-11-05")

    def __init__(self, mcp_handler, host: str = "127.0.0.1", port: int = 8080):
        """Initialize HTTP transport handler.

        Args:
            mcp_handler: The MCP handler instance to process requests; it must
                provide an awaitable ``handle_request(request, headers=...)``.
            host: Host to bind to (default: 127.0.0.1 for security); only
                stored here.
            port: Port to bind to (default: 8080); only stored here.
        """
        self.mcp_handler = mcp_handler
        self.host = host
        self.port = port
        self.app = FastAPI(
            title="QDrant Loader MCP Server",
            description="HTTP transport for Model Context Protocol",
            version="1.0.0",
        )
        # session_id -> {"messages": [queued SSE payloads], "created_at": epoch seconds}
        self.sessions: dict[str, dict[str, Any]] = {}
        # Track in-flight requests to support graceful shutdown
        self._inflight_requests: int = 0
        self._inflight_non_stream_requests: int = 0
        self._counter_lock = asyncio.Lock()
        self._setup_middleware()
        self._setup_routes()
        logger.info(f"HTTP transport handler initialized on {host}:{port}")

    def _setup_middleware(self):
        """Set up the CORS middleware and the in-flight request tracker."""
        # CORS for browser clients: localhost / 127.0.0.1 on any port only.
        # This mirrors the host/scheme checks in ``_validate_origin``.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origin_regex=r"https?://(localhost|127\.0\.0\.1)(:[0-9]+)?",
            allow_credentials=True,
            allow_methods=["GET", "POST", "OPTIONS"],
            allow_headers=["*"],
        )

        @self.app.middleware("http")
        async def _track_inflight_requests(request: Request, call_next):
            """Track in-flight requests to enable graceful shutdown.

            We treat GET /mcp as a streaming (SSE) request and track it separately so
            shutdown logic can prioritize draining non-streaming requests first.
            """
            is_stream = request.method.upper() == "GET" and request.url.path == "/mcp"
            async with self._counter_lock:
                self._inflight_requests += 1
                if not is_stream:
                    self._inflight_non_stream_requests += 1

            try:
                response = await call_next(request)
                return response
            finally:
                async with self._counter_lock:
                    # Guard against going below zero in unexpected edge cases
                    self._inflight_requests = max(0, self._inflight_requests - 1)
                    if not is_stream:
                        self._inflight_non_stream_requests = max(
                            0, self._inflight_non_stream_requests - 1
                        )

    def _setup_routes(self):
        """Register the /mcp (POST, GET, OPTIONS) and /health routes."""

        @self.app.post("/mcp")
        async def handle_mcp_post(request: Request, response: Response):
            """Handle client-to-server messages via HTTP POST."""
            logger.debug("Received POST request to /mcp")
            # ``response`` is FastAPI's temporal response; headers set on it
            # (e.g. Mcp-Session-Id) are merged into the returned JSON response.
            return await self._handle_post_request(request, response)

        @self.app.get("/mcp")
        async def handle_mcp_get(request: Request):
            """Handle server-to-client streaming via SSE."""
            logger.debug("Received GET request to /mcp for SSE streaming")
            return await self._handle_get_request(request)

        @self.app.options("/mcp")
        async def handle_mcp_options(request: Request):
            """Handle CORS OPTIONS requests for the /mcp endpoint.

            The request origin is echoed back only when it passes
            ``_validate_origin``. A wildcard ``*`` is never sent: browsers
            reject it together with ``Access-Control-Allow-Credentials: true``,
            and it would contradict the localhost-only policy.
            """
            headers = {
                "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
                "Access-Control-Allow-Headers": "*",
            }
            origin = request.headers.get("origin")
            if origin and self._validate_origin(origin):
                headers["Access-Control-Allow-Origin"] = origin
                headers["Access-Control-Allow-Credentials"] = "true"
                headers["Vary"] = "Origin"
            return Response(status_code=200, headers=headers)

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "transport": "http", "protocol": "mcp"}

    @staticmethod
    def _generate_session_id() -> str:
        """Return a new, unpredictable session identifier.

        uuid4 is random rather than clock-based, so concurrent requests cannot
        collide and IDs cannot be guessed from the request time.
        """
        return f"session_{uuid.uuid4().hex}"

    @staticmethod
    def _jsonrpc_error(request_id: Any, code: int, message: str) -> dict[str, Any]:
        """Build a JSON-RPC 2.0 error response object.

        JSON-RPC 2.0 requires ``id`` on every response; it must be ``None``
        when the request id could not be determined (parse error, invalid
        request).
        """
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": code, "message": message},
        }

    async def _handle_post_request(
        self, request: Request, response: Response | None = None
    ) -> dict[str, Any]:
        """Process MCP messages from HTTP POST requests.

        Args:
            request: FastAPI request object
            response: Optional FastAPI response whose headers are sent with the
                returned payload. When given and the server generates a new
                session, the session ID is sent to the client in the
                ``Mcp-Session-Id`` header so it can be reused.

        Returns:
            The MCP handler's response, or a JSON-RPC 2.0 error object when the
            body is not valid JSON (-32700), is not a JSON object (-32600), or
            processing fails (-32603).

        Raises:
            HTTPException: 403 when the Origin header is not a localhost
                origin; an HTTPException raised by the MCP handler is
                propagated unchanged.
        """
        # Security: Validate Origin header (DNS rebinding protection)
        origin = request.headers.get("origin")
        if not self._validate_origin(origin):
            logger.warning(f"Invalid origin header: {origin}")
            raise HTTPException(status_code=403, detail="Invalid origin")

        # Protocol version validation (optional with backward compatibility):
        # unsupported versions are logged but not rejected.
        protocol_version = request.headers.get("mcp-protocol-version")
        if not self._validate_protocol_version(protocol_version):
            logger.warning(f"Unsupported protocol version: {protocol_version}")

        # Session management
        session_id = request.headers.get("mcp-session-id")
        is_new_session = not session_id
        if is_new_session:
            session_id = self._generate_session_id()
            logger.debug(f"Generated new session ID: {session_id}")

        try:
            mcp_request = await request.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError (invalid UTF-8 body)
            # are both ValueError subclasses: a JSON-RPC parse error.
            logger.error(f"JSON decode error: {e}")
            return self._jsonrpc_error(None, -32700, "Invalid JSON in request")
        except Exception as e:
            logger.error(f"Error processing MCP request: {e}", exc_info=True)
            return self._jsonrpc_error(None, -32603, "Internal server error")

        if not isinstance(mcp_request, dict):
            # Valid JSON but not a single request object (e.g. a list or number).
            logger.error(
                "Invalid MCP request: expected a JSON object, got "
                f"{type(mcp_request).__name__}"
            )
            return self._jsonrpc_error(None, -32600, "Invalid Request")

        logger.debug(
            f"Processing MCP request: {mcp_request.get('method', 'unknown')}"
        )

        try:
            # Add headers context to request processing
            result = await self.mcp_handler.handle_request(
                mcp_request, headers=dict(request.headers)
            )
        except HTTPException:
            # Re-raise HTTPException so FastAPI can handle it properly
            raise
        except Exception as e:
            logger.error(f"Error processing MCP request: {e}", exc_info=True)
            return self._jsonrpc_error(
                mcp_request.get("id"), -32603, "Internal server error"
            )

        # Register the session so server-initiated messages can be queued for
        # it via add_session_message and streamed over SSE.
        if session_id not in self.sessions:
            self.sessions[session_id] = {"messages": [], "created_at": time.time()}

        # Only advertise the generated session once it has actually been
        # stored, so clients never receive an ID that does not exist.
        if is_new_session and response is not None:
            response.headers["Mcp-Session-Id"] = session_id

        logger.debug("Successfully processed MCP request, returning response")
        return result

    async def _handle_get_request(self, request: Request) -> StreamingResponse:
        """Handle SSE streaming for server-to-client messages.

        Messages queued for the session named by the ``mcp-session-id`` header
        are flushed on each poll (about once per second), followed by a
        heartbeat event.

        Args:
            request: FastAPI request object

        Returns:
            StreamingResponse with SSE events

        Raises:
            HTTPException: 403 when the Origin header is not a localhost origin.
        """
        # Security: the SSE stream needs the same DNS rebinding protection as
        # POST, otherwise a foreign page could read queued session messages.
        origin = request.headers.get("origin")
        if not self._validate_origin(origin):
            logger.warning(f"Invalid origin header: {origin}")
            raise HTTPException(status_code=403, detail="Invalid origin")

        session_id = request.headers.get("mcp-session-id")
        logger.debug(f"Setting up SSE stream for session: {session_id}")

        async def event_stream():
            """Generate SSE events for the session."""
            try:
                while True:
                    # Check for new messages in session
                    if session_id and session_id in self.sessions:
                        session = self.sessions[session_id]
                        if session.get("messages"):
                            for message in session["messages"]:
                                logger.debug(f"Sending SSE message: {message}")
                                yield f"data: {json.dumps(message)}\n\n"
                            session["messages"] = []  # Clear sent messages

                    # Send heartbeat to keep connection alive
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"

                    await asyncio.sleep(1.0)  # Poll interval (1 second)

            except asyncio.CancelledError:
                logger.debug(f"SSE stream cancelled for session: {session_id}")
                raise
            except Exception as e:
                logger.error(f"Error in SSE stream: {e}", exc_info=True)
                yield f"data: {json.dumps({'type': 'error', 'message': 'Stream processing error'})}\n\n"

        # No Access-Control-* headers here: CORSMiddleware adds the echoed
        # origin for allowed origins, whereas a hard-coded "*" would let pages
        # on any origin read the stream.
        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
        )

    def get_inflight_request_counts(self) -> dict[str, int]:
        """Return current in-flight request counters.

        - total: All requests currently being processed
        - non_streaming: Requests that are not long-lived streams (e.g., SSE)
        """
        return {
            "total": self._inflight_requests,
            "non_streaming": self._inflight_non_stream_requests,
        }

    def has_inflight_non_streaming(self) -> bool:
        """Whether there are non-streaming in-flight requests."""
        return self._inflight_non_stream_requests > 0

    def _validate_origin(self, origin: str | None) -> bool:
        """Validate Origin header to prevent DNS rebinding attacks.

        The origin is parsed, and its scheme and hostname are compared
        exactly. Look-alike origins such as ``http://localhost.evil.com`` or
        ``http://localhost@evil.com`` are therefore rejected. Any port on an
        allowed host (e.g. ``http://localhost:3000``) is still accepted.

        Args:
            origin: Origin header value

        Returns:
            True if the origin is missing (non-browser clients) or is an
            http(s) origin on localhost / 127.0.0.1, False otherwise.
        """
        if not origin:
            # Allow requests without Origin (non-browser clients)
            return True

        try:
            parsed = urlsplit(origin)
        except ValueError:
            # Malformed origin, e.g. unbalanced IPv6 brackets.
            return False

        # ``hostname`` is lower-cased and excludes any userinfo and port.
        return (
            parsed.scheme in self._ALLOWED_ORIGIN_SCHEMES
            and parsed.hostname in self._ALLOWED_ORIGIN_HOSTS
        )

    def _validate_protocol_version(self, version: str | None) -> bool:
        """Validate MCP protocol version header.

        Args:
            version: Protocol version from header

        Returns:
            True if the version is missing (backward compatibility) or is one
            of the supported protocol revisions, False otherwise.
        """
        if not version:
            # Allow for backward compatibility
            return True

        return version in self._SUPPORTED_PROTOCOL_VERSIONS

    def add_session_message(self, session_id: str, message: dict[str, Any]):
        """Queue a message for a session's SSE stream.

        Messages for unknown sessions are silently dropped.

        Args:
            session_id: Session identifier
            message: JSON-serializable message to add to the session queue
        """
        if session_id in self.sessions:
            self.sessions[session_id]["messages"].append(message)
            logger.debug(f"Added message to session {session_id}: {message}")

    def cleanup_sessions(self, max_age_seconds: int = 3600):
        """Remove sessions older than ``max_age_seconds``.

        Age is measured from the session's ``created_at`` timestamp; sessions
        lacking one are treated as created at epoch 0 and therefore expire.

        Args:
            max_age_seconds: Maximum age of sessions in seconds (default: 1 hour)
        """
        current_time = time.time()
        # Collect first, then delete, so the dict is not mutated while iterating.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if current_time - session.get("created_at", 0) > max_age_seconds
        ]

        for session_id in expired_sessions:
            del self.sessions[session_id]
            logger.debug(f"Cleaned up expired session: {session_id}")

        if expired_sessions:
            logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")