fix: add retry logic for OpenAI and Azure OpenAI (#315)
jjmachan authored Nov 21, 2023
1 parent de8a1f0 commit 7ffc2f0
Showing 2 changed files with 84 additions and 4 deletions.
87 changes: 84 additions & 3 deletions src/ragas/llms/openai.py
@@ -1,29 +1,104 @@
from __future__ import annotations

import asyncio
import logging
import os
import typing as t
from abc import abstractmethod
from dataclasses import dataclass, field

import openai
from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.schema import Generation, LLMResult
from openai import AsyncAzureOpenAI, AsyncClient, AsyncOpenAI
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from ragas.async_utils import run_async_tasks
from ragas.exceptions import AzureOpenAIKeyNotFound, OpenAIKeyNotFound
from ragas.llms.base import RagasLLM
from ragas.llms.langchain import _compute_token_usage_langchain
from ragas.utils import NO_KEY
from ragas.utils import NO_KEY, get_debug_mode

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks
from langchain.prompts import ChatPromptTemplate

logger = logging.getLogger(__name__)

errors = [
    openai.APITimeoutError,
    openai.APIConnectionError,
    openai.RateLimitError,
    openai.InternalServerError,
]


def create_base_retry_decorator(
error_types: t.List[t.Type[BaseException]],
max_retries: int = 1,
run_manager: t.Optional[
t.Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> t.Callable[[t.Any], t.Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""

log_level = logging.WARNING if get_debug_mode() else logging.DEBUG
_logging = before_sleep_log(logger, log_level)

def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
logger.error(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None

min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between retries, starting at 4 seconds (the min clamp)
# and capped at 10 seconds thereafter
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(error)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=retry_instance,
before_sleep=_before_sleep,
)


retry_decorator = create_base_retry_decorator(errors, max_retries=4)


class OpenAIBase(RagasLLM):
def __init__(self, model: str, _api_key_env_var: str) -> None:
def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None:
self.model = model
self._api_key_env_var = _api_key_env_var
self.timeout = timeout

# api key
key_from_env = os.getenv(self._api_key_env_var, NO_KEY)
@@ -83,6 +158,7 @@ def generate(
llm_output = _compute_token_usage_langchain(llm_results)
return LLMResult(generations=generations, llm_output=llm_output)

@retry_decorator
async def agenerate(
self,
prompt: ChatPromptTemplate,
@@ -112,9 +188,13 @@ def __post_init__(self):
self._client_init()

def _client_init(self):
self._client = AsyncOpenAI(api_key=self.api_key)
self._client = AsyncOpenAI(api_key=self.api_key, timeout=self.timeout)

def validate_api_key(self):
# before validating, check if the api key is already set
api_key = os.getenv(self._api_key_env_var, NO_KEY)
if api_key != NO_KEY:
self._client.api_key = api_key
if self.llm.api_key == NO_KEY:
raise OpenAIKeyNotFound

@@ -136,6 +216,7 @@ def _client_init(self):
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
timeout=self.timeout,
)

def validate_api_key(self):
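For orientation, here is a minimal standalone sketch of the retry behaviour the decorator above sets up, using the same tenacity knobs (reraise, stop_after_attempt(4), wait_exponential(multiplier=1, min=4, max=10), a before_sleep logger). FlakyError and flaky_call are hypothetical stand-ins for the OpenAI error types and the wrapped agenerate call; they are not part of the commit.

import asyncio
import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class FlakyError(Exception):
    """Hypothetical stand-in for the OpenAI error types listed in `errors`."""


attempts = 0


# Same parameters as retry_decorator above: up to 4 attempts, exponential
# backoff clamped to the 4-10 second window, re-raise once retries run out.
@retry(
    reraise=True,
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(FlakyError),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise FlakyError("transient failure")
    return "ok"


if __name__ == "__main__":
    print(asyncio.run(flaky_call()))  # fails twice, waits, succeeds on attempt 3

Because of the min=4 clamp, both sleeps in this toy run are roughly four seconds; with max=10 the waits would never exceed ten seconds no matter how many retries were configured.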
1 change: 0 additions & 1 deletion src/ragas/utils.py
@@ -12,7 +12,6 @@
@lru_cache(maxsize=1)
def get_debug_mode() -> bool:
if os.environ.get(DEBUG_ENV_VAR, str(False)).lower() == "true":
logging.basicConfig(level=logging.DEBUG)
return True
else:
return False

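The utils.py change above removes the implicit logging.basicConfig(level=logging.DEBUG) side effect from get_debug_mode; the flag now only selects how loudly retries are logged (see log_level in create_base_retry_decorator). A small sketch of what an application is now expected to do itself, assuming the environment variable behind DEBUG_ENV_VAR is RAGAS_DEBUG (that name is an assumption, not shown in this diff):

import logging
import os

# Assumption: DEBUG_ENV_VAR resolves to "RAGAS_DEBUG". Set it before the first
# call to get_debug_mode(), since the result is cached by lru_cache.
os.environ["RAGAS_DEBUG"] = "true"

# With basicConfig removed from get_debug_mode, the application configures
# logging itself so the before_sleep_log retry warnings are actually emitted.
logging.basicConfig(level=logging.WARNING)

from ragas.utils import get_debug_mode  # noqa: E402

print(get_debug_mode())  # True -> retries are logged at WARNING level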