Coverage for src/qdrant_loader_mcp_server/transport/http_handler.py: 81%

129 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-13 09:20 +0000

1"""HTTP Transport Handler for MCP Protocol 2025-06-18.""" 

2 

import asyncio
import json
import time
import uuid
from typing import Any
from urllib.parse import urlsplit

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from ..utils.logging import LoggingConfig

13 

# Module-level logger, obtained from the project's LoggingConfig helper.
logger = LoggingConfig.get_logger(__name__)

15 

16 

class HTTPTransportHandler:
    """HTTP Transport Handler for MCP Protocol with SSE streaming support.

    Wraps a FastAPI application exposing:

    * ``POST /mcp``    - client-to-server JSON-RPC messages
    * ``GET /mcp``     - server-to-client Server-Sent Events stream
    * ``OPTIONS /mcp`` - CORS preflight
    * ``GET /health``  - health check

    Outbound messages for a session are queued in ``self.sessions`` (via
    ``add_session_message``) and drained by that session's SSE stream.
    """

    # Origin schemes/hosts accepted by ``_validate_origin``. Kept in line with
    # the CORS ``allow_origin_regex`` configured in ``_setup_middleware``.
    _ALLOWED_ORIGIN_SCHEMES = frozenset({"http", "https"})
    _ALLOWED_ORIGIN_HOSTS = frozenset({"localhost", "127.0.0.1"})

    # MCP protocol revisions accepted in the ``mcp-protocol-version`` header.
    _SUPPORTED_PROTOCOL_VERSIONS = frozenset(
        {"2025-06-18", "2025-03-26", "2024-11-05"}
    )

    def __init__(self, mcp_handler, host: str = "127.0.0.1", port: int = 8080):
        """Initialize HTTP transport handler.

        Builds the FastAPI app and registers middleware and routes; this
        class does not start a server itself.

        Args:
            mcp_handler: The MCP handler instance to process requests. It must
                provide an awaitable ``handle_request(request, headers=...)``.
            host: Host to bind to (default: 127.0.0.1 for security)
            port: Port to bind to (default: 8080)
        """
        self.mcp_handler = mcp_handler
        self.host = host
        self.port = port
        self.app = FastAPI(
            title="QDrant Loader MCP Server",
            description="HTTP transport for Model Context Protocol",
            version="1.0.0",
        )
        # session_id -> {"messages": [queued outbound messages],
        #                "created_at": time.time() at creation}
        self.sessions: dict[str, dict[str, Any]] = {}
        # Track in-flight requests to support graceful shutdown
        self._inflight_requests: int = 0
        self._inflight_non_stream_requests: int = 0
        self._counter_lock = asyncio.Lock()
        self._setup_middleware()
        self._setup_routes()
        logger.info(f"HTTP transport handler initialized on {host}:{port}")

    def _setup_middleware(self):
        """Setup FastAPI middleware for CORS and in-flight request tracking."""
        # Add CORS middleware for browser clients; only local origins allowed.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origin_regex=r"https?://(localhost|127\.0\.0\.1)(:[0-9]+)?",
            allow_credentials=True,
            allow_methods=["GET", "POST", "OPTIONS"],
            allow_headers=["*"],
        )

        @self.app.middleware("http")
        async def _track_inflight_requests(request: Request, call_next):
            """Track in-flight requests to enable graceful shutdown.

            We treat GET /mcp as a streaming (SSE) request and track it separately so
            shutdown logic can prioritize draining non-streaming requests first.
            """
            # NOTE(review): call_next presumably returns once the response has
            # started, so a long-lived SSE body may stop counting towards
            # "total" before the stream actually ends -- confirm if relied on.
            is_stream = request.method.upper() == "GET" and request.url.path == "/mcp"
            async with self._counter_lock:
                self._inflight_requests += 1
                if not is_stream:
                    self._inflight_non_stream_requests += 1

            try:
                response = await call_next(request)
                return response
            finally:
                async with self._counter_lock:
                    # Guard against going below zero in unexpected edge cases
                    self._inflight_requests = max(0, self._inflight_requests - 1)
                    if not is_stream:
                        self._inflight_non_stream_requests = max(
                            0, self._inflight_non_stream_requests - 1
                        )

    def _setup_routes(self):
        """Setup FastAPI routes for MCP endpoints."""

        @self.app.post("/mcp")
        async def handle_mcp_post(request: Request):
            """Handle client-to-server messages via HTTP POST."""
            logger.debug("Received POST request to /mcp")
            return await self._handle_post_request(request)

        @self.app.get("/mcp")
        async def handle_mcp_get(request: Request):
            """Handle server-to-client streaming via SSE."""
            logger.debug("Received GET request to /mcp for SSE streaming")
            return await self._handle_get_request(request)

        @self.app.options("/mcp")
        async def handle_mcp_options():
            """Handle CORS preflight requests for /mcp endpoint."""
            # NOTE(review): browsers reject "Allow-Origin: *" combined with
            # "Allow-Credentials: true", and "*" is broader than the
            # localhost-only CORS policy above -- confirm this is intended.
            return Response(
                status_code=200,
                headers={
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
                    "Access-Control-Allow-Headers": "*",
                    "Access-Control-Allow-Credentials": "true",
                },
            )

        @self.app.get("/health")
        async def health_check():
            """Health check endpoint."""
            return {"status": "healthy", "transport": "http", "protocol": "mcp"}

    @staticmethod
    def _new_session_id() -> str:
        """Return a fresh, hard-to-guess session identifier."""
        # uuid4 rather than a millisecond timestamp: two requests in the same
        # millisecond would otherwise share (and drain) one message queue, and
        # timestamps are trivially guessable by a third party.
        return f"session_{uuid.uuid4().hex}"

    async def _handle_post_request(self, request: Request) -> dict[str, Any]:
        """Process MCP messages from HTTP POST requests.

        Args:
            request: FastAPI request object

        Returns:
            The MCP handler's response, or a JSON-RPC 2.0 error object:
            -32700 when the body is not valid JSON, -32603 for any other
            failure. Error objects always carry an ``id`` member.

        Raises:
            HTTPException: 403 when the Origin header is not a local origin.
        """
        # Bound before the try so the error path can echo the request id.
        mcp_request: Any = None
        try:
            # Security: Validate Origin header (DNS rebinding protection)
            origin = request.headers.get("origin")
            if not self._validate_origin(origin):
                logger.warning(f"Invalid origin header: {origin}")
                raise HTTPException(status_code=403, detail="Invalid origin")

            # Protocol version validation (optional with backward compatibility)
            protocol_version = request.headers.get("mcp-protocol-version")
            if not self._validate_protocol_version(protocol_version):
                logger.warning(f"Unsupported protocol version: {protocol_version}")
                # Continue with warning but don't reject for backward compatibility

            # Session management
            session_id = request.headers.get("mcp-session-id")
            if not session_id:
                session_id = self._new_session_id()
                logger.debug(f"Generated new session ID: {session_id}")

            # Process MCP request
            mcp_request = await request.json()
            logger.debug(
                f"Processing MCP request: {mcp_request.get('method', 'unknown')}"
            )

            # Add headers context to request processing
            response = await self.mcp_handler.handle_request(
                mcp_request, headers=dict(request.headers)
            )

            # Register the session so server-initiated messages can later be
            # queued for it and delivered over SSE.
            if session_id not in self.sessions:
                self.sessions[session_id] = {"messages": [], "created_at": time.time()}

            logger.debug("Successfully processed MCP request, returning response")
            return response

        except HTTPException:
            # Re-raise HTTPException so FastAPI can handle it properly
            raise
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {e}")
            # JSON-RPC 2.0: if the id cannot be determined it MUST be null.
            return {
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32700, "message": "Invalid JSON in request"},
            }
        except Exception as e:
            logger.error(f"Error processing MCP request: {e}", exc_info=True)
            request_id = (
                mcp_request.get("id") if isinstance(mcp_request, dict) else None
            )
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {"code": -32603, "message": "Internal server error"},
            }

    async def _handle_get_request(self, request: Request) -> StreamingResponse:
        """Handle SSE streaming for server-to-client messages.

        The stream polls the session's queue once per second, emits each
        queued message as an SSE ``data:`` event, then emits a heartbeat.

        Args:
            request: FastAPI request object

        Returns:
            StreamingResponse with SSE events

        Raises:
            HTTPException: 403 when the Origin header is not a local origin.
        """
        # Security: the SSE endpoint needs the same DNS-rebinding protection
        # as POST; without it a rebound page could read a session's stream.
        origin = request.headers.get("origin")
        if not self._validate_origin(origin):
            logger.warning(f"Invalid origin header: {origin}")
            raise HTTPException(status_code=403, detail="Invalid origin")

        session_id = request.headers.get("mcp-session-id")
        logger.debug(f"Setting up SSE stream for session: {session_id}")

        async def event_stream():
            """Generate SSE events for the session."""
            try:
                while True:
                    # Check for new messages in session
                    if session_id and session_id in self.sessions:
                        session = self.sessions[session_id]
                        if session.get("messages"):
                            for message in session["messages"]:
                                logger.debug(f"Sending SSE message: {message}")
                                yield f"data: {json.dumps(message)}\n\n"
                            session["messages"] = []  # Clear sent messages

                    # Send heartbeat to keep connection alive
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"

                    await asyncio.sleep(1.0)  # Poll interval (1 second)

            except asyncio.CancelledError:
                logger.debug(f"SSE stream cancelled for session: {session_id}")
                raise
            except Exception as e:
                logger.error(f"Error in SSE stream: {e}", exc_info=True)
                yield f"data: {json.dumps({'type': 'error', 'message': 'Stream processing error'})}\n\n"

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "*",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
        )

    def get_inflight_request_counts(self) -> dict[str, int]:
        """Return current in-flight request counters.

        - total: All requests currently being processed
        - non_streaming: Requests that are not long-lived streams (e.g., SSE)
        """
        return {
            "total": self._inflight_requests,
            "non_streaming": self._inflight_non_stream_requests,
        }

    def has_inflight_non_streaming(self) -> bool:
        """Whether there are non-streaming in-flight requests."""
        return self._inflight_non_stream_requests > 0

    def _validate_origin(self, origin: str | None) -> bool:
        """Validate Origin header to prevent DNS rebinding attacks.

        Accepts only ``http``/``https`` origins whose host is exactly
        ``localhost`` or ``127.0.0.1`` (any port) and that carry no userinfo.
        The host is compared after parsing, not by string prefix, so
        look-alikes such as ``http://localhost.evil.com`` or
        ``http://localhost@evil.com`` are rejected.

        Args:
            origin: Origin header value

        Returns:
            True if origin is absent or valid, False otherwise
        """
        if not origin:
            # Allow requests without Origin (non-browser clients)
            return True

        try:
            parts = urlsplit(origin)
        except ValueError:
            # Malformed origin, e.g. an unterminated IPv6 literal.
            return False

        # urlsplit lower-cases both scheme and hostname.
        return (
            parts.scheme in self._ALLOWED_ORIGIN_SCHEMES
            and parts.hostname in self._ALLOWED_ORIGIN_HOSTS
            and parts.username is None
        )

    def _validate_protocol_version(self, version: str | None) -> bool:
        """Validate MCP protocol version header.

        Args:
            version: Protocol version from header

        Returns:
            True if version is absent or supported, False otherwise
        """
        if not version:
            # Allow for backward compatibility
            return True

        return version in self._SUPPORTED_PROTOCOL_VERSIONS

    def add_session_message(self, session_id: str, message: dict[str, Any]):
        """Add a message to a session for SSE streaming.

        Messages for unknown session IDs are silently dropped.

        Args:
            session_id: Session identifier
            message: Message to add to session queue
        """
        if session_id in self.sessions:
            self.sessions[session_id]["messages"].append(message)
            logger.debug(f"Added message to session {session_id}: {message}")

    def cleanup_sessions(self, max_age_seconds: int = 3600):
        """Clean up old sessions.

        Age is measured from the session's ``created_at`` timestamp; a session
        missing that key is treated as created at epoch 0 (always expired).

        Args:
            max_age_seconds: Maximum age of sessions in seconds (default: 1 hour)
        """
        current_time = time.time()
        # Collect first: deleting while iterating the dict would raise.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if current_time - session.get("created_at", 0) > max_age_seconds
        ]

        for session_id in expired_sessions:
            del self.sessions[session_id]
            logger.debug(f"Cleaned up expired session: {session_id}")

        if expired_sessions:
            logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")