fix(azure/): support passing headers to azure openai endpoints
Fixes #6217
krrishdholakia committed Dec 12, 2024
1 parent aa7f416 commit 5d9db82
Showing 3 changed files with 25 additions and 13 deletions.
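
For orientation, a minimal usage sketch of what the fix enables (not part of the commit; the deployment name, endpoint, and key below are placeholders): custom HTTP headers passed to litellm.embedding via the headers keyword (or the existing extra_headers) should now be forwarded to the Azure OpenAI endpoint.

import litellm

# Placeholder values for illustration only.
response = litellm.embedding(
    model="azure/text-embedding-ada-002",             # hypothetical Azure deployment
    input=["hello world"],
    api_base="https://my-endpoint.openai.azure.com",  # placeholder endpoint
    api_version="2023-07-01-preview",
    api_key="my-azure-api-key",                       # placeholder key
    headers={"Ocp-Apim-Subscription-Key": "hello-world-testing"},
)
print(response)
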
litellm/llms/azure/azure.py (5 additions & 2 deletions)
@@ -342,7 +342,8 @@ def completion( # noqa: PLR0915
         headers: Optional[dict] = None,
         client=None,
     ):
-        super().completion()
+        if headers:
+            optional_params["extra_headers"] = headers
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -851,8 +852,10 @@ def embedding(
         max_retries: Optional[int] = None,
         client=None,
         aembedding=None,
+        headers: Optional[dict] = None,
     ) -> litellm.EmbeddingResponse:
-        super().embedding()
+        if headers:
+            optional_params["extra_headers"] = headers
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
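
Both handler changes in this file follow the same two-line pattern. As a standalone sketch (apply_headers is an illustrative name, not a function in the codebase), the effect is simply to mirror a caller-supplied headers dict into optional_params["extra_headers"], which the Azure client path already forwards as extra request headers:

from typing import Optional

def apply_headers(optional_params: dict, headers: Optional[dict]) -> dict:
    # Mirror caller-supplied headers into the kwargs handed to the downstream client.
    if headers:
        optional_params["extra_headers"] = headers
    return optional_params

params = apply_headers({}, {"Ocp-Apim-Subscription-Key": "hello-world-testing"})
assert params["extra_headers"]["Ocp-Apim-Subscription-Key"] == "hello-world-testing"
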
litellm/main.py (2 additions & 3 deletions)
@@ -3171,6 +3171,7 @@ def embedding( # noqa: PLR0915
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     extra_headers = kwargs.get("extra_headers", None)
+    headers = kwargs.get("headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -3281,9 +3282,6 @@ def embedding( # noqa: PLR0915
                     "azure_ad_token", None
                 ) or get_secret_str("AZURE_AD_TOKEN")

-            if extra_headers is not None:
-                optional_params["extra_headers"] = extra_headers
-
             api_key = (
                 api_key
                 or litellm.api_key
@@ -3311,6 +3309,7 @@ def embedding( # noqa: PLR0915
                 client=client,
                 aembedding=aembedding,
                 max_retries=max_retries,
+                headers=headers or extra_headers,
             )
         elif (
             model in litellm.open_ai_embedding_models
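
The routing change in main.py also implies a precedence when both keywords are supplied: headers wins, with extra_headers as the fallback. A tiny illustration of the headers or extra_headers expression (values are placeholders):

headers = {"x-custom": "from-headers"}              # placeholder
extra_headers = {"x-custom": "from-extra-headers"}  # placeholder

chosen = headers or extra_headers  # mirrors `headers=headers or extra_headers` above
assert chosen == {"x-custom": "from-headers"}

# If only extra_headers is provided, the fallback applies.
assert (None or extra_headers) == {"x-custom": "from-extra-headers"}
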
tests/llm_translation/test_azure_openai.py (18 additions & 8 deletions)
@@ -113,7 +113,14 @@ def test_process_azure_headers_with_dict_input():
         ({"prompt": "Hello world"}, "image_generation"),
     ],
 )
-def test_azure_extra_headers(input, call_type):
+@pytest.mark.parametrize(
+    "header_value",
+    [
+        "headers",
+        "extra_headers",
+    ],
+)
+def test_azure_extra_headers(input, call_type, header_value):
     from litellm import embedding, image_generation

     http_client = Client()
@@ -128,18 +135,21 @@ def test_azure_extra_headers(input, call_type):
             func = embedding
         elif call_type == "image_generation":
             func = image_generation
-        response = func(
-            model="azure/chatgpt-v-2",
-            api_base="https://openai-gpt-4-test-v-1.openai.azure.com",
-            api_version="2023-07-01-preview",
-            api_key="my-azure-api-key",
-            extra_headers={
+
+        data = {
+            "model": "azure/chatgpt-v-2",
+            "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+            "api_version": "2023-07-01-preview",
+            "api_key": "my-azure-api-key",
+            header_value: {
                 "Authorization": "my-bad-key",
                 "Ocp-Apim-Subscription-Key": "hello-world-testing",
             },
             **input,
-        )
+        }
+        response = func(**data)
         print(response)

     except Exception as e:
         print(e)
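
The updated test now exercises both spellings, headers and extra_headers, across the embedding and image_generation call types. It can be run locally with something like: pytest tests/llm_translation/test_azure_openai.py -k test_azure_extra_headers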
