Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Feat/continue to generate #401

Merged
merged 17 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def store_conversation(
# Ref: https://stackoverflow.com/questions/63026648/errormessage-class-decimal-inexact-class-decimal-rounded-while
"TotalPrice": decimal(str(conversation.total_price)),
"LastMessageId": conversation.last_message_id,
"ShouldContinue": conversation.should_continue,
}

if conversation.bot_id:
Expand Down Expand Up @@ -236,6 +237,7 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM
},
last_message_id=item["LastMessageId"],
bot_id=item["BotId"] if "BotId" in item else None,
should_continue=item.get("ShouldContinue", False),
)
logger.info(f"Found conversation: {conv}")
return conv
Expand Down
1 change: 1 addition & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ConversationModel(BaseModel):
message_map: dict[str, MessageModel]
last_message_id: str
bot_id: str | None
should_continue: bool


class ConversationMeta(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions backend/app/routes/published_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def post_message(request: Request, message_input: ChatInputWithoutBotId):
message_id=response_message_id,
),
bot_id=bot_id,
continue_generate=message_input.continue_generate,
)

try:
Expand Down
2 changes: 2 additions & 0 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class ChatInput(BaseSchema):
conversation_id: str
message: MessageInput
bot_id: str | None = Field(None)
continue_generate: bool = Field(False)


class ChatOutput(BaseSchema):
Expand Down Expand Up @@ -113,6 +114,7 @@ class Conversation(BaseSchema):
message_map: dict[str, MessageOutput]
last_message_id: str
bot_id: str | None
should_continue: bool


class NewTitleInput(BaseSchema):
Expand Down
1 change: 1 addition & 0 deletions backend/app/routes/schemas/published_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ChatInputWithoutBotId(BaseSchema):
If not provided, new conversation will be generated.""",
)
message: MessageInputWithoutMessageId
continue_generate: bool = Field(False)


class ChatOutputWithoutBotId(BaseSchema):
Expand Down
4 changes: 2 additions & 2 deletions backend/app/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run(self, args: dict):
)
response = self.on_stop(
OnStopInput(
full_token=concatenated,
full_token=concatenated.rstrip(),
stop_reason=stop_reason,
input_token_count=input_token_count,
output_token_count=output_token_count,
Expand Down Expand Up @@ -134,7 +134,7 @@ def run(self, args: dict):
)
res = self.on_stop(
OnStopInput(
full_token=concatenated,
full_token=concatenated.rstrip(),
stop_reason=stop_reason,
input_token_count=input_token_count,
output_token_count=output_token_count,
Expand Down
67 changes: 42 additions & 25 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,33 +180,36 @@ def prepare_conversation(
message_map=initial_message_map,
last_message_id="",
bot_id=chat_input.bot_id,
should_continue=False,
)

# Append user chat input to the conversation
if chat_input.message.message_id:
message_id = chat_input.message.message_id
else:
message_id = str(ULID())
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore
# If the "Generate continue" button is pressed, a new_message is not generated.
if not chat_input.continue_generate:
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore
statefb marked this conversation as resolved.
Show resolved Hide resolved

return (message_id, conversation, bot)

Expand Down Expand Up @@ -335,7 +338,9 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
messages = trace_to_root(
node_id=chat_input.message.parent_message_id, message_map=message_map
)
messages.append(chat_input.message) # type: ignore

if not chat_input.continue_generate:
messages.append(chat_input.message) # type: ignore

# Create payload to invoke Bedrock
args = compose_args(
Expand All @@ -357,6 +362,8 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
response = get_bedrock_response(args) # type: ignore
reply_txt = response["outputs"][0]["text"] # type: ignore

reply_txt = reply_txt.rstrip()

# Used chunks for RAG generation
if bot and bot.display_retrieved_chunks and is_running_on_lambda():
if len(search_results) > 0:
Expand Down Expand Up @@ -395,14 +402,23 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
used_chunks=used_chunks,
thinking_log=thinking_log,
)
conversation.message_map[assistant_msg_id] = message

# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
if chat_input.continue_generate:
conversation.message_map[conversation.last_message_id].content[
0
].body += reply_txt
else:
conversation.message_map[assistant_msg_id] = message

# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id

conversation.total_price += price

# If continued, save the state
conversation.should_continue = response.stop_reason == "max_tokens"

# Store updated conversation
store_conversation(user_id, conversation)
# Update bot last used time
Expand Down Expand Up @@ -570,6 +586,7 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation:
last_message_id=conversation.last_message_id,
message_map=message_map,
bot_id=conversation.bot_id,
should_continue=conversation.should_continue,
)
return output

Expand Down
79 changes: 46 additions & 33 deletions backend/app/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def process_chat_input(
node_id=chat_input.message.parent_message_id,
message_map=message_map,
)
messages.append(chat_input.message) # type: ignore

if not chat_input.continue_generate:
messages.append(chat_input.message) # type: ignore

args = compose_args(
messages,
Expand All @@ -203,43 +205,54 @@ def on_stream(token: str, **kwargs) -> None:
gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send)

def on_stop(arg: OnStopInput, **kwargs) -> None:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
if chat_input.continue_generate:
# For continue generate
conversation.message_map[conversation.last_message_id].content[
0
].body += arg.full_token
else:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
)
)
)

# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(content_type="text", body=arg.full_token, media_type=None)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(
content_type="text", body=arg.full_token, media_type=None
)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id

conversation.total_price += arg.price

# If continued, save the state
conversation.should_continue = arg.stop_reason == "max_tokens"

# Store conversation before finish streaming so that front-end can avoid 404 issue
store_conversation(user_id, conversation)
last_data_to_send = json.dumps(
Expand Down
6 changes: 6 additions & 0 deletions backend/tests/test_repositories/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_store_and_find_conversation(self):
},
last_message_id="x",
bot_id=None,
should_continue=False,
)

# Test storing conversation
Expand Down Expand Up @@ -186,6 +187,7 @@ def test_store_and_find_conversation(self):
self.assertEqual(found_conversation.last_message_id, "x")
self.assertEqual(found_conversation.total_price, 100)
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

# Test update title
response = change_conversation_title(
Expand Down Expand Up @@ -260,6 +262,7 @@ def test_store_and_find_large_conversation(self):
message_map=large_message_map,
last_message_id="msg_9",
bot_id=None,
should_continue=False,
)

# Test storing large conversation with a small threshold
Expand All @@ -275,6 +278,7 @@ def test_store_and_find_large_conversation(self):
self.assertEqual(found_conversation.total_price, 200)
self.assertEqual(found_conversation.last_message_id, "msg_9")
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

message_map = found_conversation.message_map
self.assertEqual(len(message_map), 10)
Expand Down Expand Up @@ -335,6 +339,7 @@ def setUp(self) -> None:
},
last_message_id="x",
bot_id=None,
should_continue=False,
)
conversation2 = ConversationModel(
id="2",
Expand Down Expand Up @@ -365,6 +370,7 @@ def setUp(self) -> None:
},
last_message_id="x",
bot_id="1",
should_continue=False,
)
bot1 = BotModel(
id="1",
Expand Down
Loading
Loading