From 13465f01224c3c2f0ba3dac183d708a545dcecd5 Mon Sep 17 00:00:00 2001 From: Jithin James Date: Tue, 21 Nov 2023 10:48:01 +0530 Subject: [PATCH] fix: openai env var load after init and before score also (#316) --- src/ragas/embeddings/base.py | 12 ++++++++++-- src/ragas/llms/openai.py | 12 ++++++++++-- tests/unit/test_llm.py | 8 ++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/ragas/embeddings/base.py b/src/ragas/embeddings/base.py index 634a7da15..19ee8c38c 100644 --- a/src/ragas/embeddings/base.py +++ b/src/ragas/embeddings/base.py @@ -40,7 +40,11 @@ def __init__(self, api_key: str = NO_KEY): def validate_api_key(self): if self.openai_api_key == NO_KEY: - raise OpenAIKeyNotFound + os_env_key = os.getenv("OPENAI_API_KEY", NO_KEY) + if os_env_key != NO_KEY: + self.api_key = os_env_key + else: + raise OpenAIKeyNotFound class AzureOpenAIEmbeddings(BaseAzureOpenAIEmbeddings, RagasEmbeddings): @@ -73,7 +77,11 @@ def __init__( def validate_api_key(self): if self.openai_api_key == NO_KEY: - raise AzureOpenAIKeyNotFound + os_env_key = os.getenv("AZURE_OPENAI_API_KEY", NO_KEY) + if os_env_key != NO_KEY: + self.api_key = os_env_key + else: + raise AzureOpenAIKeyNotFound @dataclass diff --git a/src/ragas/llms/openai.py b/src/ragas/llms/openai.py index 2fc4b98cd..2ef678b5b 100644 --- a/src/ragas/llms/openai.py +++ b/src/ragas/llms/openai.py @@ -196,7 +196,11 @@ def validate_api_key(self): if api_key != NO_KEY: self._client.api_key = api_key if self.llm.api_key == NO_KEY: - raise OpenAIKeyNotFound + os_env_key = os.getenv(self._api_key_env_var, NO_KEY) + if os_env_key != NO_KEY: + self.api_key = os_env_key + else: + raise OpenAIKeyNotFound @dataclass @@ -221,4 +225,8 @@ def _client_init(self): def validate_api_key(self): if self.llm.api_key == NO_KEY: - raise AzureOpenAIKeyNotFound + os_env_key = os.getenv(self._api_key_env_var, NO_KEY) + if os_env_key != NO_KEY: + self.api_key = os_env_key + else: + raise AzureOpenAIKeyNotFound diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index d414e4f21..903fe23f0 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -136,3 +136,11 @@ def test_validate_api_key_for_different_llms( obj, api_key = factory(with_api_key=True) assert obj.validate_api_key assert obj.api_key == api_key + + # assert loading key from environment variables after instantiation + if environ_key in os.environ: + os.environ.pop(environ_key) + obj = factory(with_api_key=False) + os.environ[environ_key] = "random-key-102848595" + assert obj.validate_api_key() is None + assert obj.api_key == "random-key-102848595"