feat: add xAI model provider (langgenius#10272)
hjlarry authored and jiangzhijie committed Nov 14, 2024
1 parent 15df8ec commit b7e5fc1
Showing 10 changed files with 372 additions and 0 deletions.
63 changes: 63 additions & 0 deletions api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
@@ -0,0 +1,63 @@
model: grok-beta
label:
  en_US: Grok beta
model_type: llm
features:
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 2.0
    precision: 1
    required: true
    help:
      en_US: "The sampling temperature controls the randomness of the output. The value lies in the range [0.0, 2.0]: the higher the value, the more random and creative the output; the lower the value, the more stable it is. Adjust either top_p or temperature to suit your needs, and avoid tuning both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 2.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "Nucleus sampling over the range [0.0, 1.0]: the model samples tokens from the highest-probability candidates whose cumulative probability reaches top_p; when top_p is 0, this parameter is disabled. Adjust either top_p or temperature to suit your needs, and avoid tuning both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0, 1.0]。top_p 值确定模型从累积概率达到 top_p 的最高概率候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: 0
    max: 2.0
    precision: 1
    required: false
    help:
      en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their frequency in the text so far, decreasing the model's likelihood of repeating the same line verbatim."
      zh_Hans: "介于 0 和 2.0 之间的数字。正值会根据新标记在文本中迄今为止的出现频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
37 changes: 37 additions & 0 deletions api/core/model_runtime/model_providers/x/llm/llm.py
@@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        # Forward `user` as well; the original call dropped it.
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials) -> None:
        # Default to the official endpoint when endpoint_url is missing or empty.
        # Using .get(...) avoids the KeyError/TypeError that direct indexing
        # raised for absent values, a case `... or default` alone did not cover.
        credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url") or "https://api.x.ai/v1"))
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
25 changes: 25 additions & 0 deletions api/core/model_runtime/model_providers/x/x.py
@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class XAIProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        If validation fails, raise an exception.

        :param credentials: provider credentials, whose form is defined in `provider_credential_schema`.
        """
        try:
            # Validate against a concrete model; grok-beta is the provider's predefined LLM.
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(model="grok-beta", credentials=credentials)
        except CredentialsValidateFailedError:
            raise
        except Exception as ex:
            # Lazy %s formatting keeps the log call cheap when logging is disabled.
            logger.exception("%s credentials validate failed", self.get_provider_schema().provider)
            raise ex
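
A hypothetical usage sketch (not part of the commit, and assuming the provider class can be constructed directly) showing how this provider-level validation is exercised; the credentials dict mirrors the form fields defined in x.yaml below:

import os

from core.model_runtime.model_providers.x.x import XAIProvider

# Provider validation delegates to grok-beta's credential check, so a bad
# key surfaces here as CredentialsValidateFailedError.
provider = XAIProvider()
provider.validate_provider_credentials(
    credentials={
        "api_key": os.environ["XAI_API_KEY"],
        "endpoint_url": "https://api.x.ai/v1",
    }
)
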
38 changes: 38 additions & 0 deletions api/core/model_runtime/model_providers/x/x.yaml
@@ -0,0 +1,38 @@
provider: x
label:
  en_US: xAI
description:
  en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
  en_US: x-ai-logo.svg
icon_large:
  en_US: x-ai-logo.svg
help:
  title:
    en_US: Get your token from xAI
    zh_Hans: 从 xAI 获取 token
  url:
    en_US: https://x.ai/api
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: endpoint_url
      label:
        en_US: API Base
      type: text-input
      required: false
      default: https://api.x.ai/v1
      placeholder:
        zh_Hans: 在此输入您的 API Base
        en_US: Enter your API Base
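
Since endpoint_url is optional here, credentials collected through this form may arrive with the field blank; the _add_custom_parameters hook in llm.py above then substitutes the official endpoint. A small illustrative sketch (placeholder values, not real keys):

# Both credential dicts resolve to the same endpoint after
# XAILargeLanguageModel._add_custom_parameters runs:
explicit = {"api_key": "xai-xxxx", "endpoint_url": "https://api.x.ai/v1"}
blank = {"api_key": "xai-xxxx", "endpoint_url": ""}  # falls back to https://api.x.ai/v1
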
4 changes: 4 additions & 0 deletions api/tests/integration_tests/.env.example
@@ -95,3 +95,7 @@ GPUSTACK_API_KEY=

# Gitee AI Credentials
GITEE_AI_API_KEY=

# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=
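
The integration tests below read these values with os.environ.get, so an unset XAI_API_BASE arrives as None; the credentials.get(...) fallback in llm.py is what keeps that case from crashing. A quick sketch of the lookup, assuming this .env file has been loaded into the environment:

import os

# None when the variable is unset; the provider then falls back to
# https://api.x.ai/v1 in _add_custom_parameters.
api_key = os.environ.get("XAI_API_KEY")
api_base = os.environ.get("XAI_API_BASE")
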
204 changes: 204 additions & 0 deletions api/tests/integration_tests/model_runtime/x/test_llm.py
@@ -0,0 +1,204 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    model = XAILargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name is pinned to gpt-3.5-turbo because of the mocked client
        model.validate_credentials(
            model="gpt-3.5-turbo",
            credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
        )

    model.validate_credentials(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
    )


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=["How"],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            ),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
            PromptMessageTool(
                name="get_stock_price",
                description="Get the current stock price",
                parameters={
                    "type": "object",
                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
                    "required": ["symbol"],
                },
            ),
        ],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert isinstance(result.message, AssistantPromptMessage)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="foo",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
        else:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = XAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77
