Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "Continue Generate" and "Generate stop" button #280

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backend/app/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
NewTitleInput,
ProposedTitle,
RelatedDocumentsOutput,
StopGenerateOutput,
)
from app.usecases.chat import (
chat,
stop_generate,
fetch_conversation,
fetch_related_documents,
propose_conversation_title,
Expand All @@ -40,6 +42,18 @@ def post_message(request: Request, chat_input: ChatInput):
return output


@router.post(
"/conversation/stop-generate", response_model=StopGenerateOutput
)
def post_stop_generate(request: Request, conversation_id: str):
"""Stop generate chat message
NOTE: POST method is used to avoid query string length limit
"""
current_user: User = request.state.current_user
output = stop_generate(user_id=current_user.id, conversation_id=conversation_id)
return output


@router.post(
"/conversation/related-documents", response_model=list[RelatedDocumentsOutput]
)
Expand Down
5 changes: 5 additions & 0 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class ChatOutput(BaseSchema):
create_time: float


class StopGenerateOutput(BaseSchema):
stop_generate: bool
message: str | None


class RelatedDocumentsOutput(BaseSchema):
chunk_body: str
content_type: Literal["s3", "url"]
Expand Down
15 changes: 15 additions & 0 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def insert_knowledge(
def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
user_msg_id, conversation, bot = prepare_conversation(user_id, chat_input)

# 1:DynamoDBの該当のレコードの中止フラグを折る

message_map = conversation.message_map
if bot and is_running_on_lambda():
# NOTE: `is_running_on_lambda`is a workaround for local testing due to no postgres mock.
Expand Down Expand Up @@ -304,6 +306,9 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
)
conversation.message_map[assistant_msg_id] = message

# 3-2:DynamoDBの該当のレコードの中止フラグを確認し、中止フラグが立っていた場合は送信処理を中止。
# ユーザの入力のみ保存はする。

# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
Expand Down Expand Up @@ -347,6 +352,16 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:

return output

def stop_generate(
user_id: str,
conversation_id: str,
) -> bool:
try:
# 2:DynamoDBの該当のレコードに中止フラグを立てる
return True
except:
return False


def propose_conversation_title(
user_id: str,
Expand Down
3 changes: 3 additions & 0 deletions backend/app/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def process_chat_input(
conversation.total_price += price

store_conversation(user_id, conversation)

# 3-2:DynamoDBの該当のレコードの中止フラグを確認し、中止フラグが立っていた場合はそれ以上の推論処理を中止。
# 現在の状況でDynamoDBに保存する。
else:
continue

Expand Down
5 changes: 5 additions & 0 deletions frontend/src/@types/conversation.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ export type PostMessageResponse = {
message: MessageContent;
};

export type PostStopGenerateResponse = {
stopGenerate: boolean;
message?: string;
};

export type GetRelatedDocumentsRequest = {
conversationId: string;
message: MessageContent & {
Expand Down
47 changes: 31 additions & 16 deletions frontend/src/components/ButtonSend.tsx
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
import React from 'react';
import { PiPaperPlaneRightFill, PiSpinnerGap } from 'react-icons/pi';
import { PiPaperPlaneRightFill } from 'react-icons/pi';
import { BaseProps } from '../@types/common';
import { twMerge } from 'tailwind-merge';
import { FaStop } from "react-icons/fa";
import { ImSpinner8 } from "react-icons/im";

type Props = BaseProps & {
disabled?: boolean;
loading?: boolean;
onClick: () => void;
onClickLoading: () => void;
};

const ButtonSend: React.FC<Props> = (props) => {

return (
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.disabled ? 'opacity-30' : '',
props.className
)}
onClick={props.onClick}
disabled={props.disabled || props.loading}>
{props.loading ? (
<PiSpinnerGap className="animate-spin" />
) : (
<PiPaperPlaneRightFill />
)}
</button>
<>
{props.loading ?
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.className
)}
onClick={props.onClickLoading}>
<div className="flex items-center justify-center">
<ImSpinner8 className="animate-spin p-n1" />

Check warning on line 26 in frontend/src/components/ButtonSend.tsx

View workflow job for this annotation

GitHub Actions / build

Classname 'p-n1' is not a Tailwind CSS class!
<FaStop className=" absolute text-[0.5rem]" />
</div>
</button>
:
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.disabled ? 'opacity-30' : '',
props.className
)}
onClick={props.onClick}
disabled={props.disabled}>
<PiPaperPlaneRightFill />
</button>
}
</>
);
};

Expand Down
2 changes: 2 additions & 0 deletions frontend/src/components/InputChatContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Props = BaseProps & {
dndMode?: boolean;
onSend: (content: string, base64EncodedImages?: string[]) => void;
onRegenerate: () => void;
onStopGenerate: () => void;
};

const MAX_IMAGE_WIDTH = 800;
Expand Down Expand Up @@ -276,6 +277,7 @@ const InputChatContent: React.FC<Props> = (props) => {
disabled={disabledSend || props.disabled}
loading={postingMessage}
onClick={sendContent}
onClickLoading={props.onStopGenerate}
/>
</div>
{base64EncodedImages.length > 0 && (
Expand Down
26 changes: 23 additions & 3 deletions frontend/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,25 @@ const useChat = () => {
}
};

/**
* 生成停止
*/
const postStopGenerate = () => {
conversationApi
.postStopGenerate(conversationId)
.then((res) => {
editMessage(
conversationId,
NEW_MESSAGE_ID.ASSISTANT,
res.data.message.content[0].body
);
resolve(res.data.message.content[0].body);
})
.catch((e) => {
reject(e);
});
}

/**
* 再生成
* @param props content: 内容を上書きしたい場合に設定 messageId: 再生成対象のmessageId botId: ボットの場合は設定する
Expand Down Expand Up @@ -584,6 +603,7 @@ const useChat = () => {
messages,
setCurrentMessageId,
postChat,
postStopGenerate,
regenerate,
getPostedModel,
// エラーのリトライ
Expand All @@ -601,9 +621,9 @@ const useChat = () => {
content: params.content ?? latestMessage.content[0].body,
bot: params.bot
? {
botId: params.bot.botId,
hasKnowledge: params.bot.hasKnowledge,
}
botId: params.bot.botId,
hasKnowledge: params.bot.hasKnowledge,
}
: undefined,
});
} else {
Expand Down
6 changes: 6 additions & 0 deletions frontend/src/hooks/useConversationApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
GetRelatedDocumentsResponse,
PostMessageRequest,
PostMessageResponse,
PostStopGenerateResponse,
} from '../@types/conversation';
import useHttp from './useHttp';

Expand Down Expand Up @@ -38,6 +39,11 @@ const useConversationApi = () => {
...input,
});
},
postStopGenerate: (conversationId: string) => {
return http.post<PostStopGenerateResponse>('conversation/stop-generate', {
conversationId,
});
},
getRelatedDocuments: (input: GetRelatedDocumentsRequest) => {
return http.post<GetRelatedDocumentsResponse>(
'conversation/related-documents',
Expand Down
20 changes: 14 additions & 6 deletions frontend/src/pages/ChatPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const ChatPage: React.FC = () => {
const {
postingMessage,
postChat,
postStopGenerate,
messages,
conversationId,
setConversationId,
Expand Down Expand Up @@ -104,9 +105,9 @@ const ChatPage: React.FC = () => {
const inputBotParams = useMemo(() => {
return botId
? {
botId: botId,
hasKnowledge: bot?.hasKnowledge ?? false,
}
botId: botId,
hasKnowledge: bot?.hasKnowledge ?? false,
}
: undefined;
}, [bot?.hasKnowledge, botId]);

Expand All @@ -121,6 +122,13 @@ const ChatPage: React.FC = () => {
[inputBotParams, postChat]
);

const onStopGenerate = useCallback(
() => {
postStopGenerate()
},
[postStopGenerate]
);

const onChangeCurrentMessageId = useCallback(
(messageId: string) => {
setCurrentMessageId(messageId);
Expand Down Expand Up @@ -303,9 +311,8 @@ const ChatPage: React.FC = () => {
messages.map((message, idx) => (
<div
key={idx}
className={`${
message.role === 'assistant' ? 'bg-aws-squid-ink/5' : ''
}`}>
className={`${message.role === 'assistant' ? 'bg-aws-squid-ink/5' : ''
}`}>
<ChatMessage
chatContent={message}
onChangeMessageId={onChangeCurrentMessageId}
Expand Down Expand Up @@ -358,6 +365,7 @@ const ChatPage: React.FC = () => {
}
onSend={onSend}
onRegenerate={onRegenerate}
onStopGenerate={onStopGenerate}
/>
</div>
</div>
Expand Down
Loading