Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/continue generate #361

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def store_conversation(
# Ref: https://stackoverflow.com/questions/63026648/errormessage-class-decimal-inexact-class-decimal-rounded-while
"TotalPrice": decimal(str(conversation.total_price)),
"LastMessageId": conversation.last_message_id,
"ShouldContinue": conversation.should_continue,
}

if conversation.bot_id:
Expand Down Expand Up @@ -236,6 +237,7 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM
},
last_message_id=item["LastMessageId"],
bot_id=item["BotId"] if "BotId" in item else None,
should_continue=item.get("ShouldContinue", False),
)
logger.info(f"Found conversation: {conv}")
return conv
Expand Down
1 change: 1 addition & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ConversationModel(BaseModel):
message_map: dict[str, MessageModel]
last_message_id: str
bot_id: str | None
should_continue: bool


class ConversationMeta(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class Conversation(BaseSchema):
message_map: dict[str, MessageOutput]
last_message_id: str
bot_id: str | None
should_continue: bool


class NewTitleInput(BaseSchema):
Expand Down
49 changes: 29 additions & 20 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,33 +161,41 @@ def prepare_conversation(
message_map=initial_message_map,
last_message_id="",
bot_id=chat_input.bot_id,
should_continue=False,
)

# Append user chat input to the conversation
if chat_input.message.message_id:
message_id = chat_input.message.message_id
else:
message_id = str(ULID())
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore

# TODO: Continueが空メッセージという仕様で良いか要検討
# Empty messages for continuity purposes are not added.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • 空の判定はddbから読み取ったshould_continueから判別可能だと思ったのですが、ここで敢えて空の判定を挟んでいる理由を教えてください(打ち合わせ時に教えていただいたと思うのですが、忘れてしまいました。申し訳ないのですが、ここに記録残していただけると助かります。またそれをコード中のコメントで英語で記載しておいていただけると助かります)

仮に必要な場合は、bedrock claude chatは空メッセージをフロントから受け付けない仕様なので、空の判定で良いと思います!

  • 空であることの判定は、websocket.pyに記載いただいているものと同様のもので良いと思いましたが、ここでtextかどうかも判定している理由はなんですか?
  • new_messageの直前に、continueボタンによる分岐である旨をコメントで記載お願いします!

Copy link
Contributor Author

@satoxiw satoxiw Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

空の判定はddbから読み取ったshould_continueから判別可能だと思ったのですが、ここで敢えて空の判定を挟んでいる理由を教えてください(打ち合わせ時に教えていただいたと思うのですが、忘れてしまいました。申し訳ないのですが、ここに記録残していただけると助かります。またそれをコード中のコメントで英語で記載しておいていただけると助かります)

仮にmax_tokensとなって続きがあったとしても、それを無視して別のリクエストを送るケースもあるかもという想定からです。should_continueに頼ってしまうとユーザの意図した動作とならない可能性があります。

仮に必要な場合は、bedrock claude chatは空メッセージをフロントから受け付けない仕様なので、空の判定で良いと思います!

了解です。

空であることの判定は、websocket.pyに記載いただいているものと同様のもので良いと思いましたが、ここでtextかどうかも判定している理由はなんですか?

単にtextでない場合の仕様をしっかり読み解けておらず、imageの場合にbodyが空の場合があるのかな?という想定からです。content_tyepが"image"の場合もbodyが空という状況はない感じでしょうか?

new_messageの直前に、continueボタンによる分岐である旨をコメントで記載お願いします!

分岐の直前にコメントで入れておきます。

Copy link
Contributor

@statefb statefb Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

仮にmax_tokensとなって続きがあったとしても、それを無視して別のリクエストを送るケースもあるかもという想定からです。should_continueに頼ってしまうとユーザの意図した動作とならない可能性があります。

ありがとうございます、理解しました!

単にtextでない場合の仕様をしっかり読み解けておらず、imageの場合にbodyが空の場合があるのかな?という想定からです。content_tyepが"image"の場合もbodyが空という状況はない感じでしょうか?

別途slackでの議論の通りなのですが、空文字に意味を持たせたくないので、項目追加でお願いします!下記のようにChatInput (BE)と、PostMessageRequest (FE) に追加すればOKだと思います!

class ChatInput(BaseSchema):
    conversation_id: str
    message: MessageInput
    bot_id: str | None = Field(None)
+    continue_generate: bool
export type PostMessageRequest = {
  conversationId?: string;
  message: MessageContent & {
    parentMessageId: null | string;
  };
  botId?: string;
+  continueGenerate: bool
};

if (
chat_input.message.content[0].content_type == "text"
and chat_input.message.content[0].body != ""
):
new_message = MessageModel(
role=chat_input.message.role,
content=[
ContentModel(
content_type=c.content_type,
media_type=c.media_type,
body=c.body,
)
for c in chat_input.message.content
],
model=chat_input.message.model,
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
thinking_log=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore

return (message_id, conversation, bot)

Expand Down Expand Up @@ -492,6 +500,7 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation:
last_message_id=conversation.last_message_id,
message_map=message_map,
bot_id=conversation.bot_id,
should_continue=conversation.should_continue,
)
return output

Expand Down
90 changes: 58 additions & 32 deletions backend/app/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,18 @@ def process_chat_input(
node_id=chat_input.message.parent_message_id,
message_map=message_map,
)
messages.append(chat_input.message) # type: ignore

continueGenerate = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

仮にddbに永続化されたものだけで判定が不可能である場合、この変数(continueGenerate)はprepare_conversationで返したほうが良いかと思いました(タプルに追加する)。
また細かいですが、should_continueではなくあえて変数名を変えた理由は何でしょうか?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

仮にddbに永続化されたものだけで判定が不可能である場合、この変数(continueGenerate)はprepare_conversationで返したほうが良いかと思いました(タプルに追加する)。

これは↑の改修とすこし絡みそうなので相談させてください

また細かいですが、should_continueではなくあえて変数名を変えた理由は何でしょうか?

単に修正漏れです。こちらの変数をshould_continue作成前に作っていたためです。名前合わせておきます。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上記のコメント(continue_generateをinputに追加・保存前にtrim)を反映した場合、下記のようにシンプルになるかと思いましたが、いかがでしょうか?

if not chat_input.continue_generate:
  messages.append(chat_input.message)

args = compose_args(
......

def on_stop(arg: OnStopInput, **kwargs) -> None:
        if chat_input.continue_generate:
          # For continue generate
            conversation.message_map[conversation.last_message_id].content[
                0
            ].body += arg.full_token
         else:
           ......

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

こちら、continue_generate導入に伴い上記の用に修正できるので対応します!


# TODO: 空メッセージだと続けて生成するとしているが、見直しが必要
if chat_input.message.content[0].body != "":
messages.append(chat_input.message) # type: ignore
else:
if messages[-1].role == "assistant":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • assisntatかどうかの判別をあえて入れている理由は何でしょうか?(空メッセージでない = 続きが必要 = 自動的に最後のメッセージはassistant)
  • 関数の責務を考慮すると、こちらではなくprepare_conversationで処理したほうが可読性が高いと思うので、strip処理はそちらに記載お願いします!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assisntatかどうかの判別をあえて入れている理由は何でしょうか?(空メッセージでない = 続きが必要 = 自動的に最後のメッセージはassistant)

ここでtrim処理をしなければならないメッセージは必ずassistantの継続メッセージのみとなるため、安全のため(なんらかの仕様で空メッセージが入ってきた場合を考慮して)入れているだけなので、不要かもですね。

関数の責務を考慮すると、こちらではなくprepare_conversationで処理したほうが可読性が高いと思うので、strip処理はそちらに記載お願いします!

となると、前段で行なっているtrace_to_root処理をprepare_conversationで行うことになり冗長な動きになるか、messagesprepare_conversationで作るとした場合、process_chat_inputで行なっている他の処理とのからみもあり大改修になりそうですが、いかがしましょう。

Copy link
Contributor

@statefb statefb Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

大改修になりそうですが、いかがしましょう。

すみません、確かにtrace_to_root後に適用する必要がありますね。ご指摘ありがとうございます!
意図としては、この部分で細かな前処理をなるべく避けたい、というのが意図になります。例えば保存する前にtrimするというのはいかがでしょうか(問題なく出力できるかの簡単な検証は必要だと思いますが)。
ストリーム処理の場合、backend.app.stream.pyAnthropicStreamHandlerBedrockStreamHandlerの該当箇所を下記のように変更することで、保存前にtrimが可能です。

AnthropicStreamHandlerの場合:

class AnthropicStreamHandler(BaseStreamHandler):
    """Stream handler for Anthropic models."""

    def run(self, args: dict):
       ...
            if isinstance(event, ContentBlockDeltaEvent):
                ...
            elif isinstance(event, MessageDeltaEvent):
                ...
            elif isinstance(event, MessageStopEvent):
                concatenated = "".join(completions)
                metrics = event.model_dump()["amazon-bedrock-invocationMetrics"]
                input_token_count = metrics.get("inputTokenCount")
                output_token_count = metrics.get("outputTokenCount")
                price = calculate_price(
                    self.model, input_token_count, output_token_count
                )
                response = self.on_stop(
                    OnStopInput(
+                        full_token=concatenated.strip(),
                        stop_reason=stop_reason,
                        input_token_count=input_token_count,
                        output_token_count=output_token_count,
                        price=price,
                    )
                )
                yield response
            else:
                continue

非ストリーム処理で利用されるbackend.app.usecases.chat.pychat()では、reply_txtstripする方針が綺麗にシンプルに実装できるかなと思いましたが、いかがでしょうか?

messages[-1].content[0].body = (
messages[-1].content[0].body.strip()
) # TODO: ここでstripをすることで、最終的なメッセージに影響が出ないか確認
continueGenerate = True

args = compose_args(
messages,
Expand All @@ -203,43 +214,58 @@ def on_stream(token: str, **kwargs) -> None:
gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send)

def on_stop(arg: OnStopInput, **kwargs) -> None:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
if not continueGenerate:
statefb marked this conversation as resolved.
Show resolved Hide resolved
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
)
)

# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(
content_type="text", body=arg.full_token, media_type=None
)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id

# Append entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(content_type="text", body=arg.full_token, media_type=None)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
else:
# For continue generate
conversation.message_map[conversation.last_message_id].content[
0
].body += arg.full_token

conversation.total_price += arg.price

# If continued, save the state
if arg.stop_reason == "max_tokens":
statefb marked this conversation as resolved.
Show resolved Hide resolved
conversation.should_continue = True
else:
conversation.should_continue = False

# Store conversation before finish streaming so that front-end can avoid 404 issue
store_conversation(user_id, conversation)
last_data_to_send = json.dumps(
Expand Down
4 changes: 4 additions & 0 deletions backend/tests/test_repositories/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_store_and_find_conversation(self):
},
last_message_id="x",
bot_id=None,
should_continue=False,
)

# Test storing conversation
Expand Down Expand Up @@ -185,6 +186,7 @@ def test_store_and_find_conversation(self):
self.assertEqual(found_conversation.last_message_id, "x")
self.assertEqual(found_conversation.total_price, 100)
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

# Test update title
response = change_conversation_title(
Expand Down Expand Up @@ -259,6 +261,7 @@ def test_store_and_find_large_conversation(self):
message_map=large_message_map,
last_message_id="msg_9",
bot_id=None,
should_continue=False,
)

# Test storing large conversation with a small threshold
Expand All @@ -274,6 +277,7 @@ def test_store_and_find_large_conversation(self):
self.assertEqual(found_conversation.total_price, 200)
self.assertEqual(found_conversation.last_message_id, "msg_9")
self.assertEqual(found_conversation.bot_id, None)
self.assertEqual(found_conversation.should_continue, False)

message_map = found_conversation.message_map
self.assertEqual(len(message_map), 10)
Expand Down
1 change: 1 addition & 0 deletions frontend/src/@types/conversation.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ export type MessageMap = {

export type Conversation = ConversationMeta & {
messageMap: MessageMap;
shouldContinue: boolean;
};

export type PutFeedbackRequest = {
Expand Down
13 changes: 7 additions & 6 deletions frontend/src/components/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,13 @@ const ChatMessage: React.FC<Props> = (props) => {
);
} else {
return (
<React.Fragment key={idx}>
{content.body.split('\n').map((c, idxBody) => (
<div key={idxBody}>{c}</div>
))}
</React.Fragment>
);
// [Customize]インプットメッセージもMarkdown書式で整形表示できるよう修正
<ChatMessageMarkdown
key={idx}
messageId={String(idx)}>
{content.body}
</ChatMessageMarkdown>
);
Comment on lines +160 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

自身の入力データはRawな状態で表示しておきたいです。
LLMのプロンプトはある種のコマンドだと思うので加工はせずに、入力した内容をそのまま表示させておきたいという意図です。
マークダウンとRawな入力値を切り替え表示できるとベターな気がしますが、一旦このPRでは対応なしにしておきたいです

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ファイルアップロード時に表示する関係でマークダウンにしたと思うんですが、ファイルの中身は基本的に非表示として、クリックしたら中身が見れるようなUXにしたいです(コードを読み進めて気づきました)
利用者としては、画面上でファイルをアップロードしたという事実だけがわかれば良くて、ファイルの中身は追加アクションで確認できる方がUXが良いかなと思っています(長いファイルだと、ファイルの中身が表示されて画面が占有されてしまうため)

}
})}
<ModalDialog
Expand Down
32 changes: 21 additions & 11 deletions frontend/src/components/ChatMessageMarkdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,35 @@ const ChatMessageMarkdown: React.FC<Props> = ({
// @ts-ignore
rehypePlugins={rehypePlugins}
components={{
// [Customize]ファイル名表示できるようにカスタマイズ
pre({children}) {
return (<pre className='code-container'>{children}</pre>)
},
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
// eslint-disable-next-line @typescript-eslint/no-unused-vars
code({ node, inline, className, children, ...props }) {
const match = /language-(\w+)/.exec(className || '');
const codeText = onlyText(children).replace(/\n$/, '');
const filename = match ? className?.split(":")[1] ?? undefined : undefined;

return !inline && match ? (
<CopyToClipboard codeText={codeText}>
<SyntaxHighlighter
{...props}
children={codeText}
style={vscDarkPlus}
language={match[1]}
x
PreTag="div"
wrapLongLines={true}
/>
</CopyToClipboard>
// [Customize]ファイル名を表示できるようにカスタマイズ
<div>
{filename && (<div className="code-header">{filename}</div>)}
<CopyToClipboard codeText={codeText}>
<SyntaxHighlighter
{...props}
children={codeText}
style={vscDarkPlus}
filename={filename}
language={match[1]}
x
PreTag="div"
wrapLongLines={true}
/>
</CopyToClipboard>
</div>
) : (
<code {...props} className={className}>
{children}
Expand Down
Loading
Loading