diff --git a/docs/docs/integrations/chat/oci_data_science.ipynb b/docs/docs/integrations/chat/oci_data_science.ipynb
index 6196d2025c04b..61b0a41ae1f85 100644
--- a/docs/docs/integrations/chat/oci_data_science.ipynb
+++ b/docs/docs/integrations/chat/oci_data_science.ipynb
@@ -137,7 +137,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -156,6 +156,10 @@
     "        \"temperature\": 0.2,\n",
     "        \"max_tokens\": 512,\n",
     "    }, # other model params...\n",
+    "    default_headers={\n",
+    "        \"route\": \"/v1/chat/completions\",\n",
+    "        # other request headers ...\n",
+    "    },\n",
     ")"
    ]
   },
diff --git a/libs/community/langchain_community/chat_models/oci_data_science.py b/libs/community/langchain_community/chat_models/oci_data_science.py
index cdb181df897b3..8ac571adbd8d4 100644
--- a/libs/community/langchain_community/chat_models/oci_data_science.py
+++ b/libs/community/langchain_community/chat_models/oci_data_science.py
@@ -47,6 +47,7 @@
 )

 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


 def _is_pydantic_class(obj: Any) -> bool:
@@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
 class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     """OCI Data Science Model Deployment chat model integration.

+    Prerequisite:
+        The OCI Model Deployment plugins can only be installed on
+        Python 3.9 and above. If you are working inside a notebook session,
+        try installing the Python 3.10 based conda pack and running the
+        following setup.
+
+
     Setup:
         Install ``oracle-ads`` and ``langchain-openai``.

@@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
     Key init args — client params:
         auth: dict
             ADS auth dictionary for OCI authentication.
+        default_headers: Optional[Dict]
+            The headers to be added to the Model Deployment request.

     Instantiate:
         .. code-block:: python
@@ -98,7 +108,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):

         chat = ChatOCIModelDeployment(
             endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
-            model="odsc-llm",
+            model="odsc-llm",  # this is the default model name if deployed with AQUA
             streaming=True,
             max_retries=3,
             model_kwargs={
@@ -106,6 +116,10 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
                 "temperature": 0.2,
                 # other model parameters ...
             },
+            default_headers={
+                "route": "/v1/chat/completions",
+                # other request headers ...
+            },
         )

     Invocation:
@@ -288,6 +302,25 @@ def _default_params(self) -> Dict[str, Any]:
             "stream": self.streaming,
         }

+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -701,7 +734,7 @@ def _process_response(self, response_json: dict) -> ChatResult:

         for choice in choices:
             message = _convert_dict_to_message(choice["message"])
-            generation_info = dict(finish_reason=choice.get("finish_reason"))
+            generation_info = {"finish_reason": choice.get("finish_reason")}
             if "logprobs" in choice:
                 generation_info["logprobs"] = choice["logprobs"]

diff --git a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py
index c998fbc0ec40f..e18a1a0847d0f 100644
--- a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py
+++ b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py
@@ -32,6 +32,7 @@
 from langchain_community.utilities.requests import Requests

 logger = logging.getLogger(__name__)
+DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
 DEFAULT_TIME_OUT = 300


@@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
     max_retries: int = 3
     """Maximum number of retries to make when generating."""

+    default_headers: Optional[Dict[str, Any]] = None
+    """The headers to be added to the Model Deployment request."""
+
     @model_validator(mode="before")
     @classmethod
     def validate_environment(cls, values: Dict) -> Dict:
@@ -120,12 +124,12 @@ def _headers(
         Returns:
             Dict: A dictionary containing the appropriate headers for the request.
         """
+        headers = self.default_headers or {}
         if is_async:
             signer = self.auth["signer"]
             _req = requests.Request("POST", self.endpoint, json=body)
             req = _req.prepare()
             req = signer(req)
-            headers = {}
             for key, value in req.headers.items():
                 headers[key] = value

@@ -135,7 +139,7 @@ def _headers(
                 )
             return headers

-        return (
+        headers.update(
             {
                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
                 "enable-streaming": "true",
@@ -147,6 +151,8 @@ def _headers(
             }
         )

+        return headers
+
     def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
@@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
             model="odsc-llm",
             streaming=True,
             model_kwargs={"frequency_penalty": 1.0},
+            default_headers={
+                "route": "/v1/completions",
+                # other request headers ...
+            },
         )

         llm.invoke("tell me a joke.")
@@ -426,7 +436,7 @@ def _construct_json_body(self, prompt: str, param:dict) -> dict:
     temperature: float = 0.2
     """A non-negative float that tunes the degree of randomness in generation."""

-    k: int = -1
+    k: int = 50
     """Number of most likely tokens to consider at each step."""

     p: float = 0.75
@@ -472,6 +482,25 @@ def _identifying_params(self) -> Dict[str, Any]:
             **self._default_params,
         }

+    def _headers(
+        self, is_async: Optional[bool] = False, body: Optional[dict] = None
+    ) -> Dict:
+        """Construct and return the headers for a request.
+
+        Args:
+            is_async (bool, optional): Indicates if the request is asynchronous.
+                Defaults to `False`.
+            body (optional): The request body to be included in the headers if
+                the request is asynchronous.
+
+        Returns:
+            Dict: A dictionary containing the appropriate headers for the request.
+        """
+        return {
+            "route": DEFAULT_INFERENCE_ENDPOINT,
+            **super()._headers(is_async=is_async, body=body),
+        }
+
     def _generate(
         self,
         prompts: List[str],
diff --git a/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py b/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py
index f385f2ddb09a3..f5b7dae6432c2 100644
--- a/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py
+++ b/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py
@@ -19,6 +19,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
diff --git a/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py b/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py
index c87b05f12a12d..10555528013d1 100644
--- a/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py
+++ b/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py
@@ -18,6 +18,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
     with mock.patch.object(
         llm,
         "_aiter_sse",
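
Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the header-merge order the new `_headers` overrides introduce; `merged_headers` and its literal values are hypothetical stand-ins for `BaseOCIModelDeployment._headers()` (which additionally signs async requests and adds the Authorization header) and the chat subclass override shown above.

```python
# Illustrative stand-in only; the real logic lives in
# BaseOCIModelDeployment._headers() and ChatOCIModelDeployment._headers().
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def merged_headers(default_headers=None):
    # Base class behavior: start from the user-supplied default_headers
    # (copied here to avoid mutating the caller's dict), then layer the
    # standard request headers on top of them.
    headers = dict(default_headers or {})
    headers.update(
        {
            "Content-Type": "application/json",
            "enable-streaming": "true",
        }
    )
    # Chat subclass behavior: put the default route first, so a "route"
    # supplied via default_headers wins when the dicts are unpacked.
    return {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT, **headers}


print(merged_headers()["route"])                               # /v1/chat/completions
print(merged_headers({"route": "/v1/completions"})["route"])   # /v1/completions
```

Without `default_headers`, this is the behavior the new unit tests assert (chat models default to "/v1/chat/completions", completion LLMs to "/v1/completions"); the override case follows from the dict-unpacking order in the new `_headers` methods.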