Coverage for src/qdrant_loader_core/llm/settings.py: 93%

67 statements  

coverage.py v7.10.6, created at 2025-09-08 06:01 +0000

from __future__ import annotations

import warnings
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse
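
# Tunable policies consumed by LLMSettings below: HTTP timeout/retry and
# backoff behavior, client-side rate limiting, and embedding vector sizing.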

@dataclass
class RequestPolicy:
    timeout_s: float = 30.0
    max_retries: int = 3
    backoff_s_min: float = 1.0
    backoff_s_max: float = 30.0


@dataclass
class RateLimitPolicy:
    rpm: int | None = None
    tpm: int | None = None
    concurrency: int = 5


@dataclass
class EmbeddingPolicy:
    vector_size: int | None = None


@dataclass
class LLMSettings:
    provider: str
    base_url: str | None
    api_key: str | None
    headers: dict[str, str] | None
    models: dict[str, str]
    tokenizer: str
    request: RequestPolicy
    rate_limits: RateLimitPolicy
    embeddings: EmbeddingPolicy
    api_version: str | None = None
    provider_options: dict[str, Any] | None = None

    @staticmethod
    def from_global_config(global_data: Mapping[str, Any]) -> LLMSettings:
        """Construct settings from a parsed global configuration dict.

        Supports two schemas:
        - New: global.llm
        - Legacy: global.embedding and file_conversion.markitdown
        """

        llm = (global_data or {}).get("llm") or {}
        if llm:
            return LLMSettings(
                provider=str(llm.get("provider")),
                base_url=llm.get("base_url"),
                api_key=llm.get("api_key"),
                api_version=llm.get("api_version"),
                headers=dict(llm.get("headers") or {}),
                models=dict(llm.get("models") or {}),
                tokenizer=str(llm.get("tokenizer", "none")),
                request=RequestPolicy(**(llm.get("request") or {})),
                rate_limits=RateLimitPolicy(**(llm.get("rate_limits") or {})),
                embeddings=EmbeddingPolicy(**(llm.get("embeddings") or {})),
                provider_options=dict(llm.get("provider_options") or {}),
            )

        # Legacy mapping
        embedding = (global_data or {}).get("embedding") or {}
        file_conv = (global_data or {}).get("file_conversion") or {}
        markit = (
            (file_conv.get("markitdown") or {}) if isinstance(file_conv, dict) else {}
        )

        endpoint = embedding.get("endpoint")
        # Detect Azure OpenAI in legacy endpoint to set provider accordingly
        endpoint_l = (endpoint or "").lower() if isinstance(endpoint, str) else ""
        host: str | None = None
        if endpoint_l:
            try:
                host = urlparse(endpoint_l).hostname or None
            except Exception:
                host = None
        is_azure = False
        if host:
            host_l = host.lower()
            is_azure = (
                host_l == "openai.azure.com"
                or host_l.endswith(".openai.azure.com")
                or host_l == "cognitiveservices.azure.com"
                or host_l.endswith(".cognitiveservices.azure.com")
            )
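        # Map the legacy endpoint to a provider id: Azure OpenAI hosts become
        # "azure_openai", other OpenAI endpoints "openai", and anything else
        # the generic "openai_compat".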

        if is_azure:
            provider = "azure_openai"
        elif "openai" in endpoint_l:
            provider = "openai"
        else:
            provider = "openai_compat"
        models = {
            "embeddings": embedding.get("model"),
        }
        if isinstance(markit.get("llm_model"), str):
            models["chat"] = markit.get("llm_model")

        # Emit deprecation warnings when relying on legacy fields
        try:
            if embedding or markit:
                warnings.warn(
                    (
                        "Using legacy configuration fields is deprecated. "
                        "Please migrate to 'global.llm' (see docs: configuration reference). "
                        "Mapped from: global.embedding.* and/or file_conversion.markitdown.*"
                    ),
                    category=DeprecationWarning,
                    stacklevel=2,
                )
        except Exception:
            # Best-effort warning; never break mapping
            pass

        return LLMSettings(
            provider=provider,
            base_url=endpoint,
            api_key=embedding.get("api_key"),
            api_version=None,
            headers=None,
            models=models,
            tokenizer=str(embedding.get("tokenizer", "none")),
            request=RequestPolicy(),
            rate_limits=RateLimitPolicy(),
            embeddings=EmbeddingPolicy(vector_size=embedding.get("vector_size")),
            provider_options=None,
        )
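
A minimal usage sketch, assuming the module is importable as
qdrant_loader_core.llm.settings (per the path above); the provider, endpoint,
and model names are illustrative, not defaults:

from qdrant_loader_core.llm.settings import LLMSettings

# New schema: everything lives under "llm"; omitted policies fall back to
# dataclass defaults.
settings = LLMSettings.from_global_config(
    {
        "llm": {
            "provider": "openai",
            "base_url": "https://api.openai.com/v1",
            "api_key": "sk-...",
            "models": {"chat": "gpt-4o-mini", "embeddings": "text-embedding-3-small"},
            "tokenizer": "cl100k_base",
        }
    }
)
assert settings.provider == "openai"
assert settings.request.timeout_s == 30.0  # RequestPolicy default

# Legacy schema: mapped with a DeprecationWarning; the provider is inferred
# from the endpoint host (hypothetical Azure resource name here).
legacy = LLMSettings.from_global_config(
    {
        "embedding": {
            "endpoint": "https://my-resource.openai.azure.com",
            "model": "text-embedding-3-small",
            "vector_size": 1536,
        }
    }
)
assert legacy.provider == "azure_openai"
assert legacy.embeddings.vector_size == 1536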