Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
k-nagayama-dxt committed Apr 25, 2024
1 parent 73ed577 commit acaa04d
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 25 deletions.
14 changes: 14 additions & 0 deletions backend/app/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
NewTitleInput,
ProposedTitle,
RelatedDocumentsOutput,
StopGenerateOutput,
)
from app.usecases.chat import (
chat,
stop_generate,
fetch_conversation,
fetch_related_documents,
propose_conversation_title,
Expand All @@ -40,6 +42,18 @@ def post_message(request: Request, chat_input: ChatInput):
return output


@router.post(
"/conversation/stop-generate", response_model=StopGenerateOutput
)
def post_stop_generate(request: Request, conversation_id: str):
"""Stop generate chat message
NOTE: POST method is used to avoid query string length limit
"""
current_user: User = request.state.current_user
output = stop_generate(user_id=current_user.id, conversation_id=conversation_id)
return output


@router.post(
"/conversation/related-documents", response_model=list[RelatedDocumentsOutput]
)
Expand Down
5 changes: 5 additions & 0 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class ChatOutput(BaseSchema):
create_time: float


class StopGenerateOutput(BaseSchema):
stop_generate: bool
message: str | None


class RelatedDocumentsOutput(BaseSchema):
chunk_body: str
content_type: Literal["s3", "url"]
Expand Down
15 changes: 15 additions & 0 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def insert_knowledge(
def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
user_msg_id, conversation, bot = prepare_conversation(user_id, chat_input)

# 1:DynamoDBの該当のレコードの中止フラグを折る

message_map = conversation.message_map
if bot and is_running_on_lambda():
# NOTE: `is_running_on_lambda`is a workaround for local testing due to no postgres mock.
Expand Down Expand Up @@ -304,6 +306,9 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
)
conversation.message_map[assistant_msg_id] = message

# 3-2:DynamoDBの該当のレコードの中止フラグを確認し、中止フラグが立っていた場合は送信処理を中止。
# ユーザの入力のみ保存はする。

# Append children to parent
conversation.message_map[user_msg_id].children.append(assistant_msg_id)
conversation.last_message_id = assistant_msg_id
Expand Down Expand Up @@ -347,6 +352,16 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:

return output

def stop_generate(
user_id: str,
conversation_id: str,
) -> bool:
try:
# 2:DynamoDBの該当のレコードに中止フラグを立てる
return True
except:
return False


def propose_conversation_title(
user_id: str,
Expand Down
3 changes: 3 additions & 0 deletions backend/app/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def process_chat_input(
conversation.total_price += price

store_conversation(user_id, conversation)

# 3-2:DynamoDBの該当のレコードの中止フラグを確認し、中止フラグが立っていた場合はそれ以上の推論処理を中止。
# 現在の状況でDynamoDBに保存する。
else:
continue

Expand Down
5 changes: 5 additions & 0 deletions frontend/src/@types/conversation.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ export type PostMessageResponse = {
message: MessageContent;
};

export type PostStopGenerateResponse = {
stopGenerate: boolean;
message?: string;
};

export type GetRelatedDocumentsRequest = {
conversationId: string;
message: MessageContent & {
Expand Down
47 changes: 31 additions & 16 deletions frontend/src/components/ButtonSend.tsx
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
import React from 'react';
import { PiPaperPlaneRightFill, PiSpinnerGap } from 'react-icons/pi';
import { PiPaperPlaneRightFill } from 'react-icons/pi';
import { BaseProps } from '../@types/common';
import { twMerge } from 'tailwind-merge';
import { FaStop } from "react-icons/fa";
import { ImSpinner8 } from "react-icons/im";

type Props = BaseProps & {
disabled?: boolean;
loading?: boolean;
onClick: () => void;
onClickLoading: () => void;
};

const ButtonSend: React.FC<Props> = (props) => {

return (
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.disabled ? 'opacity-30' : '',
props.className
)}
onClick={props.onClick}
disabled={props.disabled || props.loading}>
{props.loading ? (
<PiSpinnerGap className="animate-spin" />
) : (
<PiPaperPlaneRightFill />
)}
</button>
<>
{props.loading ?
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.className
)}
onClick={props.onClickLoading}>
<div className="flex items-center justify-center">
<ImSpinner8 className="animate-spin p-n1" />

Check warning on line 26 in frontend/src/components/ButtonSend.tsx

View workflow job for this annotation

GitHub Actions / build

Classname 'p-n1' is not a Tailwind CSS class!
<FaStop className=" absolute text-[0.5rem]" />
</div>
</button>
:
<button
className={twMerge(
'flex items-center justify-center rounded-xl border border-aws-sea-blue bg-white p-2 text-xl text-aws-sea-blue',
props.disabled ? 'opacity-30' : '',
props.className
)}
onClick={props.onClick}
disabled={props.disabled}>
<PiPaperPlaneRightFill />
</button>
}
</>
);
};

Expand Down
2 changes: 2 additions & 0 deletions frontend/src/components/InputChatContent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Props = BaseProps & {
dndMode?: boolean;
onSend: (content: string, base64EncodedImages?: string[]) => void;
onRegenerate: () => void;
onStopGenerate: () => void;
};

const MAX_IMAGE_WIDTH = 800;
Expand Down Expand Up @@ -276,6 +277,7 @@ const InputChatContent: React.FC<Props> = (props) => {
disabled={disabledSend || props.disabled}
loading={postingMessage}
onClick={sendContent}
onClickLoading={props.onStopGenerate}
/>
</div>
{base64EncodedImages.length > 0 && (
Expand Down
26 changes: 23 additions & 3 deletions frontend/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,25 @@ const useChat = () => {
}
};

/**
* 生成停止
*/
const postStopGenerate = () => {
conversationApi
.postStopGenerate(conversationId)
.then((res) => {
editMessage(
conversationId,
NEW_MESSAGE_ID.ASSISTANT,
res.data.message.content[0].body
);
resolve(res.data.message.content[0].body);
})
.catch((e) => {
reject(e);
});
}

/**
* 再生成
* @param props content: 内容を上書きしたい場合に設定 messageId: 再生成対象のmessageId botId: ボットの場合は設定する
Expand Down Expand Up @@ -584,6 +603,7 @@ const useChat = () => {
messages,
setCurrentMessageId,
postChat,
postStopGenerate,
regenerate,
getPostedModel,
// エラーのリトライ
Expand All @@ -601,9 +621,9 @@ const useChat = () => {
content: params.content ?? latestMessage.content[0].body,
bot: params.bot
? {
botId: params.bot.botId,
hasKnowledge: params.bot.hasKnowledge,
}
botId: params.bot.botId,
hasKnowledge: params.bot.hasKnowledge,
}
: undefined,
});
} else {
Expand Down
6 changes: 6 additions & 0 deletions frontend/src/hooks/useConversationApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
GetRelatedDocumentsResponse,
PostMessageRequest,
PostMessageResponse,
PostStopGenerateResponse,
} from '../@types/conversation';
import useHttp from './useHttp';

Expand Down Expand Up @@ -38,6 +39,11 @@ const useConversationApi = () => {
...input,
});
},
postStopGenerate: (conversationId: string) => {
return http.post<PostStopGenerateResponse>('conversation/stop-generate', {
conversationId,
});
},
getRelatedDocuments: (input: GetRelatedDocumentsRequest) => {
return http.post<GetRelatedDocumentsResponse>(
'conversation/related-documents',
Expand Down
20 changes: 14 additions & 6 deletions frontend/src/pages/ChatPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const ChatPage: React.FC = () => {
const {
postingMessage,
postChat,
postStopGenerate,
messages,
conversationId,
setConversationId,
Expand Down Expand Up @@ -104,9 +105,9 @@ const ChatPage: React.FC = () => {
const inputBotParams = useMemo(() => {
return botId
? {
botId: botId,
hasKnowledge: bot?.hasKnowledge ?? false,
}
botId: botId,
hasKnowledge: bot?.hasKnowledge ?? false,
}
: undefined;
}, [bot?.hasKnowledge, botId]);

Expand All @@ -121,6 +122,13 @@ const ChatPage: React.FC = () => {
[inputBotParams, postChat]
);

const onStopGenerate = useCallback(
() => {
postStopGenerate()
},
[postStopGenerate]
);

const onChangeCurrentMessageId = useCallback(
(messageId: string) => {
setCurrentMessageId(messageId);
Expand Down Expand Up @@ -303,9 +311,8 @@ const ChatPage: React.FC = () => {
messages.map((message, idx) => (
<div
key={idx}
className={`${
message.role === 'assistant' ? 'bg-aws-squid-ink/5' : ''
}`}>
className={`${message.role === 'assistant' ? 'bg-aws-squid-ink/5' : ''
}`}>
<ChatMessage
chatContent={message}
onChangeMessageId={onChangeCurrentMessageId}
Expand Down Expand Up @@ -358,6 +365,7 @@ const ChatPage: React.FC = () => {
}
onSend={onSend}
onRegenerate={onRegenerate}
onStopGenerate={onStopGenerate}
/>
</div>
</div>
Expand Down

0 comments on commit acaa04d

Please sign in to comment.