From 7ffc2f079b5f7990d93008058a0889483cf7d3f0 Mon Sep 17 00:00:00 2001 From: Jithin James Date: Tue, 21 Nov 2023 10:31:46 +0530 Subject: [PATCH] fix: add retry logic for OpenAI and Azure OpenAI (#315) --- src/ragas/llms/openai.py | 87 ++++++++++++++++++++++++++++++++++++++-- src/ragas/utils.py | 1 - 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/ragas/llms/openai.py b/src/ragas/llms/openai.py index d7d521223..2fc4b98cd 100644 --- a/src/ragas/llms/openai.py +++ b/src/ragas/llms/openai.py @@ -1,29 +1,104 @@ from __future__ import annotations +import asyncio +import logging import os import typing as t from abc import abstractmethod from dataclasses import dataclass, field +import openai from langchain.adapters.openai import convert_message_to_dict +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.schema import Generation, LLMResult from openai import AsyncAzureOpenAI, AsyncClient, AsyncOpenAI +from tenacity import ( + RetryCallState, + before_sleep_log, + retry, + retry_base, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from ragas.async_utils import run_async_tasks from ragas.exceptions import AzureOpenAIKeyNotFound, OpenAIKeyNotFound from ragas.llms.base import RagasLLM from ragas.llms.langchain import _compute_token_usage_langchain -from ragas.utils import NO_KEY +from ragas.utils import NO_KEY, get_debug_mode if t.TYPE_CHECKING: from langchain.callbacks.base import Callbacks from langchain.prompts import ChatPromptTemplate +logger = logging.getLogger(__name__) + +errors = [ + openai.APITimeoutError, + openai.APIConnectionError, + openai.RateLimitError, + openai.APIConnectionError, + openai.InternalServerError, +] + + +def create_base_retry_decorator( + error_types: t.List[t.Type[BaseException]], + max_retries: int = 1, + run_manager: t.Optional[ + t.Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> t.Callable[[t.Any], t.Any]: + """Create a retry decorator for a given LLM and provided list of error types.""" + + log_level = logging.WARNING if get_debug_mode() else logging.DEBUG + _logging = before_sleep_log(logger, log_level) + + def _before_sleep(retry_state: RetryCallState) -> None: + _logging(retry_state) + if run_manager: + if isinstance(run_manager, AsyncCallbackManagerForLLMRun): + coro = run_manager.on_retry(retry_state) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(coro) + else: + asyncio.run(coro) + except Exception as e: + logger.error(f"Error in on_retry: {e}") + else: + run_manager.on_retry(retry_state) + return None + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + retry_instance: "retry_base" = retry_if_exception_type(error_types[0]) + for error in error_types[1:]: + retry_instance = retry_instance | retry_if_exception_type(error) + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=retry_instance, + before_sleep=_before_sleep, + ) + + +retry_decorator = create_base_retry_decorator(errors, max_retries=4) + class OpenAIBase(RagasLLM): - def __init__(self, model: str, _api_key_env_var: str) -> None: + def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None: self.model = model self._api_key_env_var = _api_key_env_var + self.timeout = timeout # api key key_from_env = os.getenv(self._api_key_env_var, NO_KEY) @@ -83,6 +158,7 @@ def generate( llm_output = _compute_token_usage_langchain(llm_results) return LLMResult(generations=generations, llm_output=llm_output) + @retry_decorator async def agenerate( self, prompt: ChatPromptTemplate, @@ -112,9 +188,13 @@ def __post_init__(self): self._client_init() def _client_init(self): - self._client = AsyncOpenAI(api_key=self.api_key) + self._client = AsyncOpenAI(api_key=self.api_key, timeout=self.timeout) def validate_api_key(self): + # before validating, check if the api key is already set + api_key = os.getenv(self._api_key_env_var, NO_KEY) + if api_key != NO_KEY: + self._client.api_key = api_key if self.llm.api_key == NO_KEY: raise OpenAIKeyNotFound @@ -136,6 +216,7 @@ def _client_init(self): api_version=self.api_version, azure_endpoint=self.azure_endpoint, api_key=self.api_key, + timeout=self.timeout, ) def validate_api_key(self): diff --git a/src/ragas/utils.py b/src/ragas/utils.py index c5db04ba9..82a1616cd 100644 --- a/src/ragas/utils.py +++ b/src/ragas/utils.py @@ -12,7 +12,6 @@ @lru_cache(maxsize=1) def get_debug_mode() -> bool: if os.environ.get(DEBUG_ENV_VAR, str(False)).lower() == "true": - logging.basicConfig(level=logging.DEBUG) return True else: return False