Coverage for src/qdrant_loader_core/llm/providers/azure_openai.py: 80%
44 statements
coverage.py v7.10.6, created at 2025-09-08 06:01 +0000
from __future__ import annotations

from typing import Any
from urllib.parse import urlparse

try:
    from openai import AzureOpenAI  # type: ignore
except Exception:  # pragma: no cover - optional dependency surface
    AzureOpenAI = None  # type: ignore

from ...logging import LoggingConfig
from ..settings import LLMSettings
from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter
from .openai import OpenAIChat, OpenAIEmbeddings, _OpenAITokenCounter

logger = LoggingConfig.get_logger(__name__)


def _host_of(url: str | None) -> str | None:
    if not url:
        return None
    try:
        return urlparse(url).hostname or None
    except Exception:
        return None


def _validate_azure_settings(settings: LLMSettings) -> None:
    base_url = settings.base_url or ""
    if "/openai/deployments" in base_url:
        raise ValueError(
            "Azure OpenAI base_url must be the resource root (e.g. https://<resource>.openai.azure.com). Do not include /openai/deployments/... in base_url."
        )
    if not (settings.api_version and isinstance(settings.api_version, str)):
        raise ValueError(
            "Azure OpenAI requires api_version (e.g. '2024-05-01-preview') in global.llm.api_version"
        )

class AzureOpenAIProvider(LLMProvider):
    def __init__(self, settings: LLMSettings):
        self._settings = settings
        _validate_azure_settings(settings)

        self._base_host = _host_of(settings.base_url)
        if AzureOpenAI is None:
            self._client = None
        else:
            # Prefer an explicit azure_endpoint in provider_options; fall back to base_url
            provider_opts = settings.provider_options or {}
            endpoint = provider_opts.get("azure_endpoint") or settings.base_url
            kwargs: dict[str, Any] = {
                "api_key": settings.api_key,
                "api_version": settings.api_version,
            }
            if endpoint:
                kwargs["azure_endpoint"] = endpoint
            self._client = AzureOpenAI(
                **{k: v for k, v in kwargs.items() if v is not None}
            )

    def embeddings(self) -> EmbeddingsClient:
        model = self._settings.models.get("embeddings", "")
        return OpenAIEmbeddings(
            self._client, model, self._base_host, provider_label="azure_openai"
        )

    def chat(self) -> ChatClient:
        model = self._settings.models.get("chat", "")
        return OpenAIChat(
            self._client, model, self._base_host, provider_label="azure_openai"
        )

    def tokenizer(self) -> TokenCounter:
        return _OpenAITokenCounter(self._settings.tokenizer)
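
A minimal usage sketch follows. It only exercises what the listing above shows: the field names api_key, base_url, api_version, models, provider_options, and tokenizer are taken from the code, but the exact LLMSettings constructor signature and the example values are assumptions, not confirmed by this report.

# sketch: constructing the provider and its clients (hypothetical settings values)
settings = LLMSettings(
    api_key="<azure-openai-api-key>",
    base_url="https://my-resource.openai.azure.com",  # resource root only, no /openai/deployments/...
    api_version="2024-05-01-preview",
    models={"chat": "my-chat-deployment", "embeddings": "my-embedding-deployment"},
    provider_options={},  # may carry an explicit "azure_endpoint" override
    tokenizer="cl100k_base",
)

provider = AzureOpenAIProvider(settings)
chat_client = provider.chat()              # OpenAIChat with provider_label="azure_openai"
embeddings_client = provider.embeddings()  # OpenAIEmbeddings with provider_label="azure_openai"
token_counter = provider.tokenizer()       # _OpenAITokenCounter over settings.tokenizer

Because the openai import is optional, provider._client stays None when the package is not installed; the chat and embeddings wrappers still construct, which is the surface the uncovered branches above relate to.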