Skip to content

Commit

Permalink
fix: openai env var load after init and before score also (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjmachan authored Nov 21, 2023
1 parent 7ffc2f0 commit 13465f0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/ragas/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def __init__(self, api_key: str = NO_KEY):

def validate_api_key(self):
if self.openai_api_key == NO_KEY:
raise OpenAIKeyNotFound
os_env_key = os.getenv("OPENAI_API_KEY", NO_KEY)
if os_env_key != NO_KEY:
self.api_key = os_env_key
else:
raise OpenAIKeyNotFound


class AzureOpenAIEmbeddings(BaseAzureOpenAIEmbeddings, RagasEmbeddings):
Expand Down Expand Up @@ -73,7 +77,11 @@ def __init__(

def validate_api_key(self):
if self.openai_api_key == NO_KEY:
raise AzureOpenAIKeyNotFound
os_env_key = os.getenv("AZURE_OPENAI_API_KEY", NO_KEY)
if os_env_key != NO_KEY:
self.api_key = os_env_key
else:
raise AzureOpenAIKeyNotFound


@dataclass
Expand Down
12 changes: 10 additions & 2 deletions src/ragas/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,11 @@ def validate_api_key(self):
if api_key != NO_KEY:
self._client.api_key = api_key
if self.llm.api_key == NO_KEY:
raise OpenAIKeyNotFound
os_env_key = os.getenv(self._api_key_env_var, NO_KEY)
if os_env_key != NO_KEY:
self.api_key = os_env_key
else:
raise OpenAIKeyNotFound


@dataclass
Expand All @@ -221,4 +225,8 @@ def _client_init(self):

def validate_api_key(self):
if self.llm.api_key == NO_KEY:
raise AzureOpenAIKeyNotFound
os_env_key = os.getenv(self._api_key_env_var, NO_KEY)
if os_env_key != NO_KEY:
self.api_key = os_env_key
else:
raise AzureOpenAIKeyNotFound
8 changes: 8 additions & 0 deletions tests/unit/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,11 @@ def test_validate_api_key_for_different_llms(
obj, api_key = factory(with_api_key=True)
assert obj.validate_api_key
assert obj.api_key == api_key

# assert loading key from environment variables after instantiation
if environ_key in os.environ:
os.environ.pop(environ_key)
obj = factory(with_api_key=False)
os.environ[environ_key] = "random-key-102848595"
assert obj.validate_api_key() is None
assert obj.api_key == "random-key-102848595"

0 comments on commit 13465f0

Please sign in to comment.