Commit 4df5b3a

fix: remove generate calls in openai and anthropic
provos committed Nov 18, 2024
1 parent 4386576 commit 4df5b3a
Showing 3 changed files with 0 additions and 161 deletions.
44 changes: 0 additions & 44 deletions src/planai/anthropic.py
@@ -21,50 +21,6 @@ def __init__(self, api_key: str, max_tokens: int = 4096):
        self.client = Anthropic(api_key=api_key)
        self.max_tokens = max_tokens

    def generate(
        self,
        prompt: str,
        system: str = "",
        format: Literal["", "json"] = "",
        model: str = "claude-3-5-sonnet-20240620",
    ) -> Mapping[str, Any]:
        """
        Create a response using the requested Anthropic model.
        Args:
            prompt (str): The main input prompt for the model.
            system (str, optional): The system message to set the behavior of the assistant. Defaults to ''.
            format (Literal['', 'json'], optional): If set to 'json', the response will be in JSON format. Defaults to ''.
            model (str, optional): The Anthropic model to use. Defaults to 'claude-3-5-sonnet-20240620'.
            max_tokens (int, optional): Maximum number of tokens in the response. Defaults to 4096.
        Raises:
            Exception: For any API errors.
        Returns:
            Mapping[str, Any]: A dictionary containing the response and completion status.
        """
        messages = []
        messages.append({"role": "user", "content": prompt})

        try:
            response = self.client.messages.create(
                max_tokens=self.max_tokens,
                messages=messages,
                model=model,
                system=system,
            )

            # Extract the text from content blocks
            content = "".join(
                block.text for block in response.content if block.type == "text"
            )

            return {"response": content, "done": True}

        except APIError as e:
            raise e

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """
        Conduct a chat conversation using the Anthropic API.
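With generate() gone from the Anthropic wrapper, callers go through the retained chat() method shown in the context lines above. A minimal sketch of the equivalent call, assuming the wrapper class is named AnthropicWrapper (the class name and the body of chat() are outside this hunk):

    from planai.anthropic import AnthropicWrapper  # module path from this diff; class name assumed

    # Sketch only, not part of this commit: a former wrapper.generate(prompt=...)
    # call expressed against the retained chat() method. Only __init__() and the
    # chat() signature are visible in this hunk.
    wrapper = AnthropicWrapper(api_key="sk-ant-...", max_tokens=4096)
    result = wrapper.chat(
        messages=[{"role": "user", "content": "Summarize the release notes."}]
    )
    # chat() returns Dict[str, Any]; its exact shape is not shown in this diff.
    print(result)
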
54 changes: 0 additions & 54 deletions src/planai/openai.py
@@ -22,60 +22,6 @@ def __init__(self, api_key: str, max_tokens: int = 4096):
        self.client = OpenAI(api_key=api_key)
        self.max_tokens = max_tokens

    def generate(
        self,
        prompt: str,
        system: str = "",
        format: Literal["", "json"] = "",
        model: str = "gpt-3.5-turbo",
    ) -> Mapping[str, Any]:
        """
        Create a response using the requested OpenAI model.
        Args:
            prompt (str): The main input prompt for the model.
            system (str, optional): The system message to set the behavior of the assistant. Defaults to ''.
            format (Literal['', 'json'], optional): If set to 'json', the response will be in JSON format. Defaults to ''.
            model (str, optional): The OpenAI model to use. Defaults to 'gpt-3.5-turbo'.
        Raises:
            Exception: For any API errors.
        Returns:
            Mapping[str, Any]: A dictionary containing the response and completion status.
        """
        messages = []

        if system:
            messages.append({"role": "system", "content": system})
        elif format == "json":
            messages.append(
                {
                    "role": "system",
                    "content": "You are a helpful assistant that responds in JSON format.",
                }
            )

        messages.append({"role": "user", "content": prompt})

        api_params = {
            "model": model,
            "messages": messages,
            "max_tokens": self.max_tokens,
        }

        if format == "json":
            api_params["response_format"] = {"type": "json_object"}

        try:
            response = self.client.chat.completions.create(**api_params)
            content = response.choices[0].message.content

            return {"response": content, "done": True}

        except Exception as e:
            raise e

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """
        Conduct a chat conversation using the OpenAI API, with optional structured output.
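The OpenAI wrapper loses its generate() helper in the same way; OpenAIWrapper and its chat() method remain, and both appear in the tests below. A minimal sketch of the replacement call, noting that the return shape of chat() and its handling of extra keyword arguments are outside this hunk (test_chat_with_options below suggests temperature is forwarded):

    from planai.openai import OpenAIWrapper  # module path and class name from this diff

    # Sketch only, not part of this commit: the removed generate(prompt=..., system=...)
    # expressed as a chat() call with an explicit system message.
    wrapper = OpenAIWrapper(api_key="sk-...", max_tokens=4096)
    result = wrapper.chat(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Test prompt"},
        ],
        temperature=0.7,
    )
    print(result)  # Dict[str, Any]; exact shape not shown in this diff
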
63 changes: 0 additions & 63 deletions tests/planai/test_openai.py
@@ -20,59 +20,6 @@ def setUp(self):
        self.api_key = "test_api_key"
        self.openai_wrapper = OpenAIWrapper(api_key=self.api_key)

    def test_generate_basic_prompt(self):
        # Mock the response from the client
        self.mock_client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content="Test response content"))]
        )

        prompt = "Test prompt"
        response = self.openai_wrapper.generate(prompt=prompt)

        self.assertEqual(response, {"response": "Test response content", "done": True})

        # Verify that the client's chat.completions.create was called with correct parameters
        self.mock_client.chat.completions.create.assert_called_once()
        args, kwargs = self.mock_client.chat.completions.create.call_args
        self.assertEqual(kwargs["model"], "gpt-3.5-turbo")
        self.assertEqual(kwargs["max_tokens"], 4096)
        self.assertEqual(kwargs["messages"], [{"role": "user", "content": prompt}])

    def test_generate_with_system_message(self):
        self.mock_client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content="System message response"))]
        )

        prompt = "Test prompt"
        system = "System message"
        response = self.openai_wrapper.generate(prompt=prompt, system=system)

        self.assertEqual(
            response, {"response": "System message response", "done": True}
        )

        self.mock_client.chat.completions.create.assert_called_once()
        args, kwargs = self.mock_client.chat.completions.create.call_args
        expected_messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ]
        self.assertEqual(kwargs["messages"], expected_messages)

    def test_generate_with_json_format(self):
        self.mock_client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content='{"key": "value"}'))]
        )

        prompt = "Test prompt"
        response = self.openai_wrapper.generate(prompt=prompt, format="json")

        self.assertEqual(response, {"response": '{"key": "value"}', "done": True})

        self.mock_client.chat.completions.create.assert_called_once()
        args, kwargs = self.mock_client.chat.completions.create.call_args
        self.assertEqual(kwargs["response_format"], {"type": "json_object"})

    def test_chat_basic(self):
        self.mock_client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content="Chat response content"))],
@@ -166,16 +113,6 @@ def test_chat_with_options(self):
        args, kwargs = self.mock_client.chat.completions.create.call_args
        self.assertEqual(kwargs["temperature"], 0.7)

    def test_generate_exception_propagation(self):
        self.mock_client.chat.completions.create.side_effect = Exception(
            "Test exception"
        )

        with self.assertRaises(Exception) as context:
            self.openai_wrapper.generate(prompt="Test prompt")

        self.assertEqual(str(context.exception), "Test exception")

    def test_chat_exception_propagation(self):
        self.mock_client.chat.completions.create.side_effect = Exception(
            "Test exception"
