Coverage for src/qdrant_loader_core/llm/providers/azure_openai.py: 80%

44 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-08 06:01 +0000

1from __future__ import annotations 

2 

3from typing import Any 

4from urllib.parse import urlparse 

5 

6try: 

7 from openai import AzureOpenAI # type: ignore 

8except Exception: # pragma: no cover - optional dependency surface 

9 AzureOpenAI = None # type: ignore 

10 

11from ...logging import LoggingConfig 

12from ..settings import LLMSettings 

13from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter 

14from .openai import OpenAIChat, OpenAIEmbeddings, _OpenAITokenCounter 

15 

16logger = LoggingConfig.get_logger(__name__) 

17 

18 

def _host_of(url: str | None) -> str | None:
    """Return the hostname parsed from *url*, or ``None`` when there is none.

    ``None`` is returned for a missing or empty *url*, for a URL whose parsed
    hostname is empty, and whenever parsing raises.
    """
    if url:
        try:
            parsed_host = urlparse(url).hostname
        except Exception:
            # Best effort: an unparseable URL simply yields no host.
            return None
        if parsed_host:
            return parsed_host
    return None

26 

27 

def _validate_azure_settings(settings: LLMSettings) -> None:
    """Reject settings that cannot be used to build an Azure OpenAI client.

    Args:
        settings: Object exposing ``base_url`` and ``api_version``.

    Raises:
        ValueError: If ``base_url`` contains ``/openai/deployments``, or if
            ``api_version`` is empty, missing, or not a ``str``.
    """
    # A missing base_url is treated as an empty string, so it passes this check.
    if "/openai/deployments" in (settings.base_url or ""):
        raise ValueError(
            "Azure OpenAI base_url must be the resource root (e.g. https://<resource>.openai.azure.com). Do not include /openai/deployments/... in base_url."
        )

    api_version = settings.api_version
    if not api_version or not isinstance(api_version, str):
        raise ValueError(
            "Azure OpenAI requires api_version (e.g. '2024-05-01-preview') in global.llm.api_version"
        )

38 

39 

class AzureOpenAIProvider(LLMProvider):
    """LLM provider that talks to Azure OpenAI through the ``AzureOpenAI`` SDK client.

    One client is built at construction time and shared by the chat and
    embeddings wrappers, which are labelled ``"azure_openai"``.
    """

    def __init__(self, settings: LLMSettings):
        """Validate *settings* and build the SDK client when the SDK is importable.

        Args:
            settings: LLM settings; ``base_url``, ``api_key``, ``api_version``,
                ``provider_options``, ``models`` and ``tokenizer`` are read.

        Raises:
            ValueError: If ``_validate_azure_settings`` rejects *settings*.
        """
        self._settings = settings
        _validate_azure_settings(settings)

        # Prefer explicit azure_endpoint in provider_options; fallback to base_url
        provider_opts = settings.provider_options or {}
        endpoint = provider_opts.get("azure_endpoint") or settings.base_url

        # Derive the host from the endpoint the client actually targets, so the
        # value handed to the chat/embeddings wrappers stays correct when
        # azure_endpoint overrides (or stands in for) base_url.
        self._base_host = _host_of(endpoint)

        if AzureOpenAI is None:
            # The optional openai dependency failed to import; the wrappers
            # receive None as their client.
            self._client = None
            return

        kwargs: dict[str, Any] = {
            "api_key": settings.api_key,
            "api_version": settings.api_version,
        }
        if endpoint:
            kwargs["azure_endpoint"] = endpoint
        # Unset values are omitted rather than passed as None
        # (presumably so the SDK applies its own defaults - confirm).
        self._client = AzureOpenAI(
            **{k: v for k, v in kwargs.items() if v is not None}
        )

    def embeddings(self) -> EmbeddingsClient:
        """Return an embeddings client for ``models["embeddings"]`` (``""`` if unset)."""
        model = self._settings.models.get("embeddings", "")
        return OpenAIEmbeddings(
            self._client, model, self._base_host, provider_label="azure_openai"
        )

    def chat(self) -> ChatClient:
        """Return a chat client for ``models["chat"]`` (``""`` if unset)."""
        model = self._settings.models.get("chat", "")
        return OpenAIChat(
            self._client, model, self._base_host, provider_label="azure_openai"
        )

    def tokenizer(self) -> TokenCounter:
        """Return a token counter built from ``settings.tokenizer``."""
        return _OpenAITokenCounter(self._settings.tokenizer)