Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/continue to generate #401

Merged
merged 17 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def store_conversation(
# Ref: https://stackoverflow.com/questions/63026648/errormessage-class-decimal-inexact-class-decimal-rounded-while
"TotalPrice": decimal(str(conversation.total_price)),
"LastMessageId": conversation.last_message_id,
"ShouldContinue": conversation.should_continue,
}

if conversation.bot_id:
Expand Down Expand Up @@ -236,6 +237,7 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM
},
last_message_id=item["LastMessageId"],
bot_id=item["BotId"] if "BotId" in item else None,
should_continue=item.get("ShouldContinue", False),
)
logger.info(f"Found conversation: {conv}")
return conv
Expand Down
1 change: 1 addition & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ConversationModel(BaseModel):
message_map: dict[str, MessageModel]
last_message_id: str
bot_id: str | None
should_continue: bool


class ConversationMeta(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions backend/app/routes/published_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def post_message(request: Request, message_input: ChatInputWithoutBotId):
message_id=response_message_id,
),
bot_id=bot_id,
continue_generate=False,
)

try:
Expand Down
2 changes: 2 additions & 0 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ChatInput(BaseSchema):
conversation_id: str
message: MessageInput
bot_id: str | None = Field(None)
continue_generate: bool | None = Field(False)
statefb marked this conversation as resolved.
Show resolved Hide resolved


class ChatOutput(BaseSchema):
Expand Down Expand Up @@ -113,6 +114,7 @@ class Conversation(BaseSchema):
message_map: dict[str, MessageOutput]
last_message_id: str
bot_id: str | None
should_continue: bool


class NewTitleInput(BaseSchema):
Expand Down
4 changes: 2 additions & 2 deletions backend/app/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run(self, args: dict):
)
response = self.on_stop(
OnStopInput(
full_token=concatenated,
full_token=concatenated.strip(),
stop_reason=stop_reason,
input_token_count=input_token_count,
output_token_count=output_token_count,
Expand Down Expand Up @@ -134,7 +134,7 @@ def run(self, args: dict):
)
res = self.on_stop(
OnStopInput(
full_token=concatenated,
full_token=concatenated.strip(),
stop_reason=stop_reason,
input_token_count=input_token_count,
output_token_count=output_token_count,
Expand Down
44 changes: 24 additions & 20 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,33 +180,36 @@ def prepare_conversation(
message_map=initial_message_map,
last_message_id="",
bot_id=chat_input.bot_id,
should_continue=False,
)

# Append user chat input to the conversation
if chat_input.message.message_id:
message_id = chat_input.message.message_id
else:
message_id = str(ULID())
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore
# When continue_generate is set (the user asked to continue the previous response), no new message is created.
if not chat_input.continue_generate:
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore
statefb marked this conversation as resolved.
Show resolved Hide resolved

return (message_id, conversation, bot)

Expand Down Expand Up @@ -570,6 +573,7 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation:
last_message_id=conversation.last_message_id,
message_map=message_map,
bot_id=conversation.bot_id,
should_continue=conversation.should_continue,
)
return output

Expand Down
79 changes: 46 additions & 33 deletions backend/app/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def process_chat_input(
node_id=chat_input.message.parent_message_id,
message_map=message_map,
)
messages.append(chat_input.message) # type: ignore

if not chat_input.continue_generate:
messages.append(chat_input.message) # type: ignore

args = compose_args(
messages,
Expand All @@ -203,43 +205,54 @@ def on_stream(token: str, **kwargs) -> None:
gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send)

def on_stop(arg: OnStopInput, **kwargs) -> None:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
if chat_input.continue_generate:
# Continuing a previous generation: append the newly generated text to the body of the last message's first content item.
conversation.message_map[conversation.last_message_id].content[
0
].body += arg.full_token
else:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
)
)
)

# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(content_type="text", body=arg.full_token, media_type=None)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(
content_type="text", body=arg.full_token, media_type=None
)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id

conversation.total_price += arg.price

# Mark the conversation as continuable when generation stopped because of the max_tokens limit.
conversation.should_continue = arg.stop_reason == "max_tokens"

# Store the conversation before streaming finishes so that the front end does not hit a 404.
store_conversation(user_id, conversation)
last_data_to_send = json.dumps(
Expand Down
6 changes: 6 additions & 0 deletions backend/tests/test_repositories/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_store_and_find_conversation(self):
},
last_message_id="x",
bot_id=None,
should_continue=False,
)

# Test storing conversation
Expand Down Expand Up @@ -186,6 +187,7 @@ def test_store_and_find_conversation(self):
self.assertEqual(found_conversation.last_message_id, "x")
self.assertEqual(found_conversation.total_price, 100)
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

# Test update title
response = change_conversation_title(
Expand Down Expand Up @@ -260,6 +262,7 @@ def test_store_and_find_large_conversation(self):
message_map=large_message_map,
last_message_id="msg_9",
bot_id=None,
should_continue=False,
)

# Test storing large conversation with a small threshold
Expand All @@ -275,6 +278,7 @@ def test_store_and_find_large_conversation(self):
self.assertEqual(found_conversation.total_price, 200)
self.assertEqual(found_conversation.last_message_id, "msg_9")
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

message_map = found_conversation.message_map
self.assertEqual(len(message_map), 10)
Expand Down Expand Up @@ -335,6 +339,7 @@ def setUp(self) -> None:
},
last_message_id="x",
bot_id=None,
should_continue=False,
)
conversation2 = ConversationModel(
id="2",
Expand Down Expand Up @@ -365,6 +370,7 @@ def setUp(self) -> None:
},
last_message_id="x",
bot_id="1",
should_continue=False,
)
bot1 = BotModel(
id="1",
Expand Down
15 changes: 15 additions & 0 deletions backend/tests/test_usecases/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def setUp(self) -> None:
),
},
bot_id=None,
should_continue=False,
),
)

Expand All @@ -345,6 +346,7 @@ def test_continue_chat(self):
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
self.output = output
Expand Down Expand Up @@ -449,6 +451,7 @@ def setUp(self) -> None:
),
},
bot_id=None,
should_continue=False,
),
)

Expand All @@ -471,6 +474,7 @@ def test_chat(self):
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
self.output = output
Expand All @@ -497,6 +501,7 @@ def test_chat(self):
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
self.output = output
Expand Down Expand Up @@ -528,6 +533,7 @@ def setUp(self) -> None:
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
print(output)
Expand Down Expand Up @@ -599,6 +605,7 @@ def test_chat_with_private_bot(self):
message_id=None,
),
bot_id="private1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
print(output)
Expand All @@ -625,6 +632,7 @@ def test_chat_with_private_bot(self):
message_id=None,
),
bot_id="private1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
print(output)
Expand All @@ -646,6 +654,7 @@ def test_chat_with_private_bot(self):
message_id=None,
),
bot_id="private1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

Expand All @@ -671,6 +680,7 @@ def test_chat_with_public_bot(self):
message_id=None,
),
bot_id="public1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

Expand All @@ -693,6 +703,7 @@ def test_chat_with_public_bot(self):
message_id=None,
),
bot_id="private1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
print(output)
Expand All @@ -717,6 +728,7 @@ def test_fetch_conversation(self):
message_id=None,
),
bot_id="private1",
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

Expand Down Expand Up @@ -767,6 +779,7 @@ def test_agent_chat(self):
message_id=None,
),
bot_id=self.bot_id,
continue_generate=False,
)
output: ChatOutput = chat(user_id=self.user_name, chat_input=chat_input)
print(output.message.content[0].body)
Expand Down Expand Up @@ -846,6 +859,7 @@ def test_insert_knowledge(self):
},
bot_id="bot1",
last_message_id="1-user",
continue_generate=False,
)
conversation_with_context = insert_knowledge(
conversation, results, display_citation=True
Expand All @@ -872,6 +886,7 @@ def test_streaming_api(self):
message_id=None,
),
bot_id=None,
continue_generate=False,
)
user_msg_id, conversation, bot = prepare_conversation("user1", chat_input)
messages = trace_to_root(
Expand Down
Loading
Loading