From 4162ab4398d03da1fd4e8ee022a7c5d4fed1ac2e Mon Sep 17 00:00:00 2001 From: Lucas Fayoux <8889400+lfayoux@users.noreply.github.com> Date: Tue, 15 Aug 2023 05:36:42 -0400 Subject: [PATCH] fix chat tests due to conversation_id being optional (#280) --- cohere/responses/chat.py | 2 +- tests/sync/test_chat.py | 17 ----------------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/cohere/responses/chat.py b/cohere/responses/chat.py index f038ac114..cfbc984c7 100644 --- a/cohere/responses/chat.py +++ b/cohere/responses/chat.py @@ -43,7 +43,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat": response_id=response["response_id"], generation_id=response["generation_id"], message=message, - conversation_id=response["conversation_id"], + conversation_id=response.get("conversation_id"), # optional text=response.get("text"), prompt=response.get("prompt"), # optional chatlog=response.get("chatlog"), # optional diff --git a/tests/sync/test_chat.py b/tests/sync/test_chat.py index 107bde211..f303f2c6d 100644 --- a/tests/sync/test_chat.py +++ b/tests/sync/test_chat.py @@ -12,7 +12,6 @@ class TestChat(unittest.TestCase): def test_simple_success(self): prediction = co.chat("Yo what up?", max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertTrue(prediction.meta) self.assertTrue(prediction.meta["api_version"]) self.assertTrue(prediction.meta["api_version"]["version"]) @@ -23,12 +22,10 @@ def test_multi_replies(self): for _ in range(num_replies): prediction = prediction.respond("oh that's cool", max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_valid_model(self): prediction = co.chat("Yo what up?", model="medium", max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_invalid_model(self): with self.assertRaises(cohere.CohereError): @@ -37,28 +34,24 @@ def test_invalid_model(self): def test_return_chatlog(self): prediction = co.chat("Yo what up?", return_chatlog=True, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertIsNotNone(prediction.chatlog) self.assertGreaterEqual(len(prediction.chatlog), len(prediction.text)) def test_return_chatlog_false(self): prediction = co.chat("Yo what up?", return_chatlog=False, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) assert prediction.chatlog is None def test_return_prompt(self): prediction = co.chat("Yo what up?", return_prompt=True, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertIsNotNone(prediction.prompt) self.assertGreaterEqual(len(prediction.prompt), len(prediction.text)) def test_return_prompt_false(self): prediction = co.chat("Yo what up?", return_prompt=False, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) assert prediction.prompt is None def test_preamble_override(self): @@ -67,7 +60,6 @@ def test_preamble_override(self): "Yo what up?", preamble_override=preamble, return_prompt=True, return_preamble=True, max_tokens=5 ) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertIn(preamble, prediction.prompt) self.assertEqual(preamble, prediction.preamble) @@ -82,7 +74,6 @@ def test_valid_temperatures(self): for temperature in temperatures: prediction = co.chat("Yo what up?", temperature=temperature, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_stream(self): prediction = co.chat( @@ -94,7 +85,6 @@ def test_stream(self): self.assertIsInstance(prediction, cohere.responses.chat.StreamingChat) self.assertIsInstance(prediction.texts, list) self.assertEqual(len(prediction.texts), 0) - self.assertIsNone(prediction.conversation_id) self.assertIsNone(prediction.response_id) self.assertIsNone(prediction.finish_reason) @@ -111,7 +101,6 @@ def test_stream(self): expected_index += 1 self.assertEqual(prediction.texts, [expected_text]) - self.assertIsNotNone(prediction.conversation_id) self.assertIsNotNone(prediction.response_id) self.assertIsNotNone(prediction.finish_reason) @@ -127,7 +116,6 @@ def test_id(self): def test_return_preamble(self): prediction = co.chat("Yo what up?", return_preamble=True, return_prompt=True, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertIsNotNone(prediction.preamble) self.assertIsNotNone(prediction.prompt) self.assertIn(prediction.preamble, prediction.prompt) @@ -135,7 +123,6 @@ def test_return_preamble(self): def test_return_preamble_false(self): prediction = co.chat("Yo what up?", return_preamble=False, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) assert prediction.preamble is None @@ -151,7 +138,6 @@ def test_chat_history(self): max_tokens=5, ) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) self.assertIsNotNone(prediction.chatlog) self.assertIn("User: Hey!", prediction.prompt) self.assertIn("Chatbot: Hey! How can I help you?", prediction.prompt) @@ -181,7 +167,6 @@ def test_token_count(self): def test_p(self): prediction = co.chat("Yo what up?", p=0.9, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_invalid_p(self): with self.assertRaises(cohere.error.CohereError): @@ -190,7 +175,6 @@ def test_invalid_p(self): def test_k(self): prediction = co.chat("Yo what up?", k=5, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_invalid_k(self): with self.assertRaises(cohere.error.CohereError): @@ -199,7 +183,6 @@ def test_invalid_k(self): def test_logit_bias(self): prediction = co.chat("Yo what up?", logit_bias={42: 10}, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsInstance(prediction.conversation_id, str) def test_invalid_logit_bias(self): invalid = [