
Commit

refactor: removed unused generate method from the LLM interface.
All PlanAI functions use the chat completions interface, so generate was only adding unnecessary code maintenance.
provos committed Nov 18, 2024
1 parent 88c90ab commit 4386576
Showing 2 changed files with 0 additions and 83 deletions.
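
For callers that previously relied on the removed generate(prompt, system) helper, the same request now goes through the chat completions interface. Below is a minimal migration sketch, assuming the class is named LLMInterface and that chat accepts a list of role/content message dicts; the actual signature lives in src/planai/llm_interface.py.

    # Hypothetical migration sketch; the constructor arguments and the exact
    # chat() signature are assumptions, not taken from this commit.
    from planai.llm_interface import LLMInterface

    llm = LLMInterface(model_name="llama3")  # construction details depend on your setup

    # Before (removed in this commit):
    # text = llm.generate(prompt="What is the capital of France?", system="Be brief.")

    # After: the same request expressed against the chat completions interface.
    text = llm.chat(
        messages=[
            {"role": "system", "content": "Be brief."},
            {"role": "user", "content": "What is the capital of France?"},
        ]
    )
    print(text)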
26 changes: 0 additions & 26 deletions src/planai/llm_interface.py
@@ -116,32 +116,6 @@ def chat(
        )
        return response.strip() if isinstance(response, str) else response

    def _cached_generate(self, prompt: str, system: str = "", format: str = "") -> str:
        # Hash the prompt to use as the cache key
        prompt_hash = self._generate_hash(
            self.model_name + "\n" + system + "\n" + prompt
        )

        # Check if prompt response is in cache
        response = self.disk_cache.get(prompt_hash)

        if response is None:
            # If not in cache, make request to client
            response = self.client.generate(
                model=self.model_name, prompt=prompt, system=system, format=format
            )

            # Cache the response with hashed prompt as key
            self.disk_cache.set(prompt_hash, response)

        return response

    def generate(self, prompt: str, system: str = "") -> str:
        self.logger.info("Generating text with prompt: %s...", prompt[:850])
        response = self._cached_generate(prompt=prompt, system=system)
        self.logger.info("Generated text: %s...", response["response"][:850])
        return response["response"].strip()

    def _strip_text_from_json_response(self, response: str) -> str:
        pattern = r"^[^{\[]*([{\[].*[}\]])[^}\]]*$"
        match = re.search(pattern, response, re.DOTALL)
57 changes: 0 additions & 57 deletions tests/planai/test_llm_interface.py
@@ -38,52 +38,6 @@ def setUp(self):
        self.response_content = "Paris"
        self.response_data = {"message": {"content": self.response_content}}

    def test_generate_with_cache_miss(self):
        self.mock_client.generate.return_value = {"response": self.response_content}

        # Call generate
        response = self.llm_interface.generate(prompt=self.prompt, system=self.system)

        self.mock_client.generate.assert_called_once_with(
            model=self.llm_interface.model_name,
            prompt=self.prompt,
            system=self.system,
            format="",
        )
        # Since we changed to use self.response_content directly
        self.assertEqual(response, self.response_content)

    def test_generate_with_cache_hit(self):
        prompt_hash = self.llm_interface._generate_hash(
            self.llm_interface.model_name + "\n" + self.system + "\n" + self.prompt
        )
        self.llm_interface.disk_cache.set(
            prompt_hash, {"response": self.response_content}
        )

        # Call generate
        response = self.llm_interface.generate(prompt=self.prompt, system=self.system)

        # Since it's a cache hit, no chat call should happen
        self.mock_client.generate.assert_not_called()

        # Confirming expected parsing
        self.assertEqual(response, self.response_content)

    def test_generate_invalid_json_response(self):
        # Simulate invalid JSON response
        invalid_json_response = {"response": "Not a JSON {...."}
        self.mock_client.generate.return_value = invalid_json_response

        with patch("planai.llm_interface.logging.Logger") as mock_logger:
            self.llm_interface.logger = mock_logger
            response = self.llm_interface.generate(
                prompt=self.prompt, system=self.system
            )

        # Expecting the invalid content since there's no parsing
        self.assertEqual(response, "Not a JSON {....")

    def test_generate_pydantic_valid_response(self):
        output_model = DummyPydanticModel(field1="test", field2=42)
        valid_json_response = '{"field1": "test", "field2": 42}'
@@ -114,17 +68,6 @@ def test_generate_pydantic_invalid_response(self):

        self.assertIsNone(response)  # Expecting None due to parsing error

    def test_cached_generate_caching_mechanism(self):
        # First call should miss cache and make client call
        self.mock_client.generate.return_value = self.response_data
        response = self.llm_interface._cached_generate(self.prompt, self.system)
        self.assertEqual(response, self.response_data)

        # Second call should hit cache, no additional client call
        response = self.llm_interface._cached_generate(self.prompt, self.system)
        self.mock_client.generate.assert_called_once()  # Still called only once
        self.assertEqual(response, self.response_data)

    def test_generate_pydantic_with_retry_logic_and_prompt_check(self):
        # Simulate an invalid JSON response that fails to parse initially
        invalid_content = '{"field1": "test"}'

