Coverage for src/qdrant_loader/core/qdrant_manager.py: 82%

123 statements  

coverage.py v7.8.2, created at 2025-06-04 05:50 +0000

import asyncio
from typing import cast
from urllib.parse import urlparse

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import (
    Distance,
    VectorParams,
)

from ..config import Settings, get_global_config, get_settings
from ..utils.logging import LoggingConfig

logger = LoggingConfig.get_logger(__name__)


class QdrantConnectionError(Exception):
    """Custom exception for Qdrant connection errors."""

    def __init__(
        self, message: str, original_error: str | None = None, url: str | None = None
    ):
        self.message = message
        self.original_error = original_error
        self.url = url
        super().__init__(self.message)


class QdrantManager:
    def __init__(self, settings: Settings | None = None):
        """Initialize the Qdrant manager.

        Args:
            settings: The application settings
        """
        self.settings = settings or get_settings()
        self.client = None
        self.collection_name = self.settings.qdrant_collection_name
        self.logger = LoggingConfig.get_logger(__name__)
        self.batch_size = get_global_config().embedding.batch_size
        self.connect()

    def _is_api_key_present(self) -> bool:
        """
        Check if a valid API key is present.
        Returns True if the API key is a non-empty string that is not 'None' or 'null'.
        """
        api_key = self.settings.qdrant_api_key
        if not api_key:  # Catches None, empty string, etc.
            return False
        return api_key.lower() not in ["none", "null"]

    def connect(self) -> None:
        """Establish a connection to the Qdrant server."""
        try:
            # Ensure HTTPS is used when API key is present, but only for non-local URLs
            url = self.settings.qdrant_url
            api_key = (
                self.settings.qdrant_api_key if self._is_api_key_present() else None
            )

            if api_key:
                parsed_url = urlparse(url)
                # Only force HTTPS for non-local URLs
                if parsed_url.scheme != "https" and not any(
                    host in parsed_url.netloc for host in ["localhost", "127.0.0.1"]
                ):
                    url = url.replace("http://", "https://", 1)
                    self.logger.warning("Forcing HTTPS connection due to API key usage")

            try:
                self.client = QdrantClient(
                    url=url,
                    api_key=api_key,
                    timeout=60,  # 60 seconds timeout
                )
                self.logger.debug("Successfully connected to qDrant")
            except Exception as e:
                raise QdrantConnectionError(
                    "Failed to connect to qDrant: Connection error",
                    original_error=str(e),
                    url=url,
                ) from e

        except Exception as e:
            raise QdrantConnectionError(
                "Failed to connect to qDrant: Unexpected error",
                original_error=str(e),
                url=url,
            ) from e

    def _ensure_client_connected(self) -> QdrantClient:
        """Ensure the client is connected before performing operations."""
        if self.client is None:
            raise QdrantConnectionError(
                "Qdrant client is not connected. Please call connect() first."
            )
        return cast(QdrantClient, self.client)

    def create_collection(self) -> None:
        """Create a new collection if it doesn't exist."""
        try:
            client = self._ensure_client_connected()
            # Check if collection already exists
            collections = client.get_collections()
            if any(c.name == self.collection_name for c in collections.collections):
                self.logger.info(f"Collection {self.collection_name} already exists")
                return

            # Get vector size from configuration
            vector_size = get_global_config().embedding.vector_size
            if not vector_size:
                self.logger.warning(
                    "No vector_size specified in config, defaulting to 1536"
                )
                vector_size = 1536

            # Create collection with basic configuration
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

            # Create index for document_id field
            client.create_payload_index(
                collection_name=self.collection_name,
                field_name="document_id",
                field_schema={"type": "keyword"},  # type: ignore
            )

            self.logger.debug(f"Collection {self.collection_name} created successfully")
        except Exception as e:
            self.logger.error("Failed to create collection", error=str(e))
            raise

    async def upsert_points(self, points: list[models.PointStruct]) -> None:
        """Upsert points into the collection.

        Args:
            points: List of points to upsert
        """
        self.logger.debug(
            "Upserting points",
            extra={"point_count": len(points), "collection": self.collection_name},
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.upsert, collection_name=self.collection_name, points=points
            )
            self.logger.debug(
                "Successfully upserted points",
                extra={"point_count": len(points), "collection": self.collection_name},
            )
        except Exception as e:
            self.logger.error(
                "Failed to upsert points",
                extra={
                    "error": str(e),
                    "point_count": len(points),
                    "collection": self.collection_name,
                },
            )
            raise

    def search(
        self, query_vector: list[float], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection."""
        try:
            client = self._ensure_client_connected()
            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error("Failed to search collection", error=str(e))
            raise

    def search_with_project_filter(
        self, query_vector: list[float], project_ids: list[str], limit: int = 5
    ) -> list[models.ScoredPoint]:
        """Search for similar vectors in the collection with project filtering.

        Args:
            query_vector: Query vector for similarity search
            project_ids: List of project IDs to filter by
            limit: Maximum number of results to return

        Returns:
            List of scored points matching the query and project filter
        """
        try:
            client = self._ensure_client_connected()

            # Build project filter
            project_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="project_id", match=models.MatchAny(any=project_ids)
                    )
                ]
            )

            search_result = client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=project_filter,
                limit=limit,
            )
            return search_result
        except Exception as e:
            logger.error(
                "Failed to search collection with project filter",
                error=str(e),
                project_ids=project_ids,
            )
            raise

    def get_project_collections(self) -> dict[str, str]:
        """Get mapping of project IDs to their collection names.

        Returns:
            Dictionary mapping project_id to collection_name
        """
        try:
            client = self._ensure_client_connected()

            # Scroll through all points to get unique project-collection mappings
            scroll_result = client.scroll(
                collection_name=self.collection_name,
                limit=10000,  # Large limit to get all unique projects
                with_payload=True,
                with_vectors=False,
            )

            project_collections = {}
            for point in scroll_result[0]:
                if point.payload:
                    project_id = point.payload.get("project_id")
                    collection_name = point.payload.get("collection_name")
                    if project_id and collection_name:
                        project_collections[project_id] = collection_name

            return project_collections
        except Exception as e:
            logger.error("Failed to get project collections", error=str(e))
            raise

    def delete_collection(self) -> None:
        """Delete the collection."""
        try:
            client = self._ensure_client_connected()
            client.delete_collection(collection_name=self.collection_name)
            logger.debug("Collection deleted", collection=self.collection_name)
        except Exception as e:
            logger.error("Failed to delete collection", error=str(e))
            raise

    async def delete_points_by_document_id(self, document_ids: list[str]) -> None:
        """Delete points from the collection by document ID.

        Args:
            document_ids: List of document IDs to delete
        """
        self.logger.debug(
            "Deleting points by document ID",
            extra={
                "document_count": len(document_ids),
                "collection": self.collection_name,
            },
        )

        try:
            client = self._ensure_client_connected()
            await asyncio.to_thread(
                client.delete,
                collection_name=self.collection_name,
                points_selector=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id", match=models.MatchAny(any=document_ids)
                        )
                    ]
                ),
            )
            self.logger.debug(
                "Successfully deleted points",
                extra={
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
        except Exception as e:
            self.logger.error(
                "Failed to delete points",
                extra={
                    "error": str(e),
                    "document_count": len(document_ids),
                    "collection": self.collection_name,
                },
            )
            raise
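
For orientation, the sketch below shows one plausible way to exercise the class covered above. It is illustrative only and not part of the measured module: it assumes the package is importable as qdrant_loader.core.qdrant_manager, that get_settings() returns a Settings object pointing at a reachable Qdrant server, and that the 1536-dimensional zero vectors stand in for real embeddings produced elsewhere.

# Minimal usage sketch (illustrative only, not part of the covered module).
# Assumes a configured Settings instance and a running Qdrant server; the
# zero vectors below are placeholders for real embeddings.
import asyncio
import uuid

from qdrant_client.http import models

from qdrant_loader.core.qdrant_manager import QdrantManager  # assumed import path


async def main() -> None:
    manager = QdrantManager()      # connects using get_settings()
    manager.create_collection()    # no-op if the collection already exists

    point = models.PointStruct(
        id=str(uuid.uuid4()),
        vector=[0.0] * 1536,       # placeholder embedding
        payload={"document_id": "doc-1", "project_id": "proj-1"},
    )
    await manager.upsert_points([point])

    # Plain similarity search, then the project-filtered variant
    hits = manager.search(query_vector=[0.0] * 1536, limit=3)
    filtered = manager.search_with_project_filter([0.0] * 1536, ["proj-1"], limit=3)
    print([hit.id for hit in hits], [hit.id for hit in filtered])

    # Clean up the points written above
    await manager.delete_points_by_document_id(["doc-1"])


if __name__ == "__main__":
    asyncio.run(main())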