diff --git a/backend/app/repositories/conversation.py b/backend/app/repositories/conversation.py index fb1170b67..b30775f02 100644 --- a/backend/app/repositories/conversation.py +++ b/backend/app/repositories/conversation.py @@ -48,6 +48,7 @@ def store_conversation( # Ref: https://stackoverflow.com/questions/63026648/errormessage-class-decimal-inexact-class-decimal-rounded-while "TotalPrice": decimal(str(conversation.total_price)), "LastMessageId": conversation.last_message_id, + "ShouldContinue": conversation.should_continue, } if conversation.bot_id: @@ -236,6 +237,7 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM }, last_message_id=item["LastMessageId"], bot_id=item["BotId"] if "BotId" in item else None, + should_continue=item.get("ShouldContinue", False), ) logger.info(f"Found conversation: {conv}") return conv diff --git a/backend/app/repositories/models/conversation.py b/backend/app/repositories/models/conversation.py index 1118a25a4..12b7fb637 100644 --- a/backend/app/repositories/models/conversation.py +++ b/backend/app/repositories/models/conversation.py @@ -43,6 +43,7 @@ class ConversationModel(BaseModel): message_map: dict[str, MessageModel] last_message_id: str bot_id: str | None + should_continue: bool class ConversationMeta(BaseModel): diff --git a/backend/app/routes/schemas/conversation.py b/backend/app/routes/schemas/conversation.py index 5222b1dba..c61be0a66 100644 --- a/backend/app/routes/schemas/conversation.py +++ b/backend/app/routes/schemas/conversation.py @@ -112,6 +112,7 @@ class Conversation(BaseSchema): message_map: dict[str, MessageOutput] last_message_id: str bot_id: str | None + should_continue: bool class NewTitleInput(BaseSchema): diff --git a/backend/app/usecases/chat.py b/backend/app/usecases/chat.py index 924ebf5e6..6c3b4ef56 100644 --- a/backend/app/usecases/chat.py +++ b/backend/app/usecases/chat.py @@ -161,6 +161,7 @@ def prepare_conversation( message_map=initial_message_map, last_message_id="", bot_id=chat_input.bot_id, + should_continue=False, ) # Append user chat input to the conversation @@ -168,26 +169,32 @@ def prepare_conversation( message_id = chat_input.message.message_id else: message_id = str(ULID()) - new_message = MessageModel( - role=chat_input.message.role, - content=[ - ContentModel( - content_type=c.content_type, - media_type=c.media_type, - body=c.body, - ) - for c in chat_input.message.content - ], - model=chat_input.message.model, - children=[], - parent=parent_id, - create_time=current_time, - feedback=None, - used_chunks=None, - thinking_log=None, - ) - conversation.message_map[message_id] = new_message - conversation.message_map[parent_id].children.append(message_id) # type: ignore + + # If the "Generate continue" button is pressed, a new_message is not generated. + if ( + chat_input.message.content[0].content_type == "text" + and chat_input.message.content[0].body != "" + ): + new_message = MessageModel( + role=chat_input.message.role, + content=[ + ContentModel( + content_type=c.content_type, + media_type=c.media_type, + body=c.body, + ) + for c in chat_input.message.content + ], + model=chat_input.message.model, + children=[], + parent=parent_id, + create_time=current_time, + feedback=None, + used_chunks=None, + thinking_log=None, + ) + conversation.message_map[message_id] = new_message + conversation.message_map[parent_id].children.append(message_id) # type: ignore return (message_id, conversation, bot) @@ -492,6 +499,7 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation: last_message_id=conversation.last_message_id, message_map=message_map, bot_id=conversation.bot_id, + should_continue=conversation.should_continue, ) return output diff --git a/backend/app/websocket.py b/backend/app/websocket.py index ba11b57d7..046e69b1f 100644 --- a/backend/app/websocket.py +++ b/backend/app/websocket.py @@ -181,7 +181,18 @@ def process_chat_input( node_id=chat_input.message.parent_message_id, message_map=message_map, ) - messages.append(chat_input.message) # type: ignore + + should_continue = False + + # TODO: 空メッセージだと続けて生成するとしているが、見直しが必要 + if chat_input.message.content[0].body != "": + messages.append(chat_input.message) # type: ignore + else: + if messages[-1].role == "assistant": + messages[-1].content[0].body = ( + messages[-1].content[0].body.strip() + ) # TODO: ここでstripをすることで、最終的なメッセージに影響が出ないか確認 + should_continue = True args = compose_args( messages, @@ -203,43 +214,54 @@ def on_stream(token: str, **kwargs) -> None: gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send) def on_stop(arg: OnStopInput, **kwargs) -> None: - used_chunks = None - if bot and bot.display_retrieved_chunks: - if len(search_results) > 0: - used_chunks = [] - for r in filter_used_results(arg.full_token, search_results): - content_type, source_link = get_source_link(r.source) - used_chunks.append( - ChunkModel( - content=r.content, - content_type=content_type, - source=source_link, - rank=r.rank, + if should_continue: + # For continue generate + conversation.message_map[conversation.last_message_id].content[ + 0 + ].body += arg.full_token + else: + used_chunks = None + if bot and bot.display_retrieved_chunks: + if len(search_results) > 0: + used_chunks = [] + for r in filter_used_results(arg.full_token, search_results): + content_type, source_link = get_source_link(r.source) + used_chunks.append( + ChunkModel( + content=r.content, + content_type=content_type, + source=source_link, + rank=r.rank, + ) ) - ) - # Append entire completion as the last message - assistant_msg_id = str(ULID()) - message = MessageModel( - role="assistant", - content=[ - ContentModel(content_type="text", body=arg.full_token, media_type=None) - ], - model=chat_input.message.model, - children=[], - parent=user_msg_id, - create_time=get_current_time(), - feedback=None, - used_chunks=used_chunks, - thinking_log=None, - ) - conversation.message_map[assistant_msg_id] = message - # Append children to parent - conversation.message_map[user_msg_id].children.append(assistant_msg_id) - conversation.last_message_id = assistant_msg_id + # Append entire completion as the last message + assistant_msg_id = str(ULID()) + message = MessageModel( + role="assistant", + content=[ + ContentModel( + content_type="text", body=arg.full_token, media_type=None + ) + ], + model=chat_input.message.model, + children=[], + parent=user_msg_id, + create_time=get_current_time(), + feedback=None, + used_chunks=used_chunks, + thinking_log=None, + ) + conversation.message_map[assistant_msg_id] = message + # Append children to parent + conversation.message_map[user_msg_id].children.append(assistant_msg_id) + conversation.last_message_id = assistant_msg_id conversation.total_price += arg.price + # If continued, save the state + conversation.should_continue = arg.stop_reason == "max_tokens" + # Store conversation before finish streaming so that front-end can avoid 404 issue store_conversation(user_id, conversation) last_data_to_send = json.dumps( diff --git a/backend/tests/test_repositories/test_conversation.py b/backend/tests/test_repositories/test_conversation.py index d2aea1f5b..122939730 100644 --- a/backend/tests/test_repositories/test_conversation.py +++ b/backend/tests/test_repositories/test_conversation.py @@ -148,6 +148,7 @@ def test_store_and_find_conversation(self): }, last_message_id="x", bot_id=None, + should_continue=False, ) # Test storing conversation @@ -185,6 +186,7 @@ def test_store_and_find_conversation(self): self.assertEqual(found_conversation.last_message_id, "x") self.assertEqual(found_conversation.total_price, 100) self.assertEqual(found_conversation.bot_id, None) + self.assertEqual(found_conversation.should_continue, False) # Test update title response = change_conversation_title( @@ -259,6 +261,7 @@ def test_store_and_find_large_conversation(self): message_map=large_message_map, last_message_id="msg_9", bot_id=None, + should_continue=False, ) # Test storing large conversation with a small threshold @@ -274,6 +277,7 @@ def test_store_and_find_large_conversation(self): self.assertEqual(found_conversation.total_price, 200) self.assertEqual(found_conversation.last_message_id, "msg_9") self.assertEqual(found_conversation.bot_id, None) + self.assertEqual(found_conversation.should_continue, False) message_map = found_conversation.message_map self.assertEqual(len(message_map), 10) @@ -334,6 +338,7 @@ def setUp(self) -> None: }, last_message_id="x", bot_id=None, + should_continue=False, ) conversation2 = ConversationModel( id="2", @@ -364,6 +369,7 @@ def setUp(self) -> None: }, last_message_id="x", bot_id="1", + should_continue=False, ) bot1 = BotModel( id="1", diff --git a/backend/tests/test_usecases/test_chat.py b/backend/tests/test_usecases/test_chat.py index eae5a1d87..543e413da 100644 --- a/backend/tests/test_usecases/test_chat.py +++ b/backend/tests/test_usecases/test_chat.py @@ -325,6 +325,7 @@ def setUp(self) -> None: ), }, bot_id=None, + should_continue=False, ), ) @@ -449,6 +450,7 @@ def setUp(self) -> None: ), }, bot_id=None, + should_continue=False, ), ) diff --git a/frontend/src/@types/conversation.d.ts b/frontend/src/@types/conversation.d.ts index 8fc5324bb..6abe7d49d 100644 --- a/frontend/src/@types/conversation.d.ts +++ b/frontend/src/@types/conversation.d.ts @@ -85,6 +85,7 @@ export type MessageMap = { export type Conversation = ConversationMeta & { messageMap: MessageMap; + shouldContinue: boolean; }; export type PutFeedbackRequest = { diff --git a/frontend/src/components/ChatMessage.tsx b/frontend/src/components/ChatMessage.tsx index edc69233c..d6fccc452 100644 --- a/frontend/src/components/ChatMessage.tsx +++ b/frontend/src/components/ChatMessage.tsx @@ -157,12 +157,13 @@ const ChatMessage: React.FC = (props) => { ); } else { return ( - - {content.body.split('\n').map((c, idxBody) => ( -
{c}
- ))} -
- ); + // [Customize]インプットメッセージもMarkdown書式で整形表示できるよう修正 + + {content.body} + + ); } })} = ({ // @ts-ignore rehypePlugins={rehypePlugins} components={{ + // [Customize]ファイル名表示できるようにカスタマイズ + pre({children}) { + return (
{children}
) + }, // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore // eslint-disable-next-line @typescript-eslint/no-unused-vars code({ node, inline, className, children, ...props }) { const match = /language-(\w+)/.exec(className || ''); const codeText = onlyText(children).replace(/\n$/, ''); + const filename = match ? className?.split(":")[1] ?? undefined : undefined; return !inline && match ? ( - - - + // [Customize]ファイル名を表示できるようにカスタマイズ +
+ {filename && (
{filename}
)} + + + +
) : ( {children} diff --git a/frontend/src/components/InputChatContent.tsx b/frontend/src/components/InputChatContent.tsx index 7ce55f405..1d6dd3d86 100644 --- a/frontend/src/components/InputChatContent.tsx +++ b/frontend/src/components/InputChatContent.tsx @@ -9,7 +9,7 @@ import ButtonSend from './ButtonSend'; import Textarea from './Textarea'; import useChat from '../hooks/useChat'; import Button from './Button'; -import { PiArrowsCounterClockwise, PiX } from 'react-icons/pi'; +import { PiArrowsCounterClockwise, PiX, PiFileTextThin, PiArrowFatLineRight } from 'react-icons/pi'; import { TbPhotoPlus } from 'react-icons/tb'; import { useTranslation } from 'react-i18next'; import ButtonIcon from './ButtonIcon'; @@ -28,6 +28,7 @@ type Props = BaseProps & { dndMode?: boolean; onSend: (content: string, base64EncodedImages?: string[]) => void; onRegenerate: () => void; + continueGenerate: () => void; }; const MAX_IMAGE_WIDTH = 800; @@ -38,6 +39,10 @@ const useInputChatContentState = create<{ pushBase64EncodedImage: (encodedImage: string) => void; removeBase64EncodedImage: (index: number) => void; clearBase64EncodedImages: () => void; + textFiles: { name: string, content: string }[]; + pushTextFile: (file: { name: string, content: string }) => void; + removeTextFile: (index: number) => void; + clearTextFiles: () => void; previewImageUrl: string | null; setPreviewImageUrl: (url: string | null) => void; isOpenPreviewImage: boolean; @@ -71,13 +76,37 @@ const useInputChatContentState = create<{ setIsOpenPreviewImage: (isOpen) => { set({ isOpenPreviewImage: isOpen }); }, + textFiles: [], + pushTextFile: (file) => { + set({ + textFiles: produce(get().textFiles, (draft) => { + draft.push(file); + }), + }); + }, + removeTextFile: (index) => { + set({ + textFiles: produce(get().textFiles, (draft) => { + draft.splice(index, 1); + }), + }); + }, + clearTextFiles: () => { + set({ + textFiles: [], + }); + }, })); const InputChatContent: React.FC = (props) => { const { t } = useTranslation(); - const { postingMessage, hasError, messages } = useChat(); + const { postingMessage, hasError, messages, getShouldContinue } = useChat(); const { disabledImageUpload, model, acceptMediaType } = useModel(); + const extendedAcceptMediaType = useMemo(() => { + return [...acceptMediaType, '.md', '.ts', '.js']; // 追加のメディアタイプをここに追加 + }, [acceptMediaType]); + const [content, setContent] = useState(''); const { base64EncodedImages, @@ -88,10 +117,15 @@ const InputChatContent: React.FC = (props) => { setPreviewImageUrl, isOpenPreviewImage, setIsOpenPreviewImage, + textFiles, + pushTextFile, + removeTextFile, + clearTextFiles, } = useInputChatContentState(); useEffect(() => { clearBase64EncodedImages(); + clearTextFiles(); // eslint-disable-next-line react-hooks/exhaustive-deps }, []); @@ -103,20 +137,63 @@ const InputChatContent: React.FC = (props) => { return postingMessage || hasError; }, [hasError, postingMessage]); + const disableContinue = useMemo(() => { + return postingMessage || hasError; + }, [hasError, postingMessage]) + const inputRef = useRef(null); + const truncateFileName = (name: string, maxLength = 30) => { + if (name.length <= maxLength) { + return name; + } + const halfLength = Math.floor((maxLength - 3) / 2); + return `${name.slice(0, halfLength)}...${name.slice(-halfLength)}`; + }; + + const extensionToLanguageMap: { [key: string]: string } = { + 'js': 'javascript', + 'ts': 'typescript', + 'md': 'markdown', + 'txt': 'plaintext', + // 必要に応じて他の拡張子と言語のマッピングを追加 + }; + + const getFileLanguage = (filename: string) => { + const lastDotIndex = filename.lastIndexOf('.'); + if (lastDotIndex === -1) { + return ''; + } + + const ext = filename.slice(lastDotIndex + 1).toLowerCase(); + const lang = extensionToLanguageMap[ext]; + + return lang === undefined ? ext : lang; + }; + const sendContent = useCallback(() => { + const filesString = textFiles.length > 0 ? textFiles.map(file => `\`\`\`${getFileLanguage(file.name)}:${file.name}\n${file.content}\n\`\`\`\n\n`).join('') : undefined; + let message = "" + if (filesString !== undefined){ + message = content + `\n\n` + filesString + }else{ + message = content + } + props.onSend( - content, + message, !disabledImageUpload && base64EncodedImages.length > 0 ? base64EncodedImages : undefined ); setContent(''); clearBase64EncodedImages(); + clearTextFiles(); }, [ base64EncodedImages, + textFiles, clearBase64EncodedImages, + clearTextFiles, content, disabledImageUpload, props, @@ -169,6 +246,19 @@ const InputChatContent: React.FC = (props) => { [pushBase64EncodedImage] ); + const handleFileRead = useCallback( + (file: File) => { + const reader = new FileReader(); + reader.onload = () => { + if (typeof reader.result === 'string') { + pushTextFile({ name: file.name, content: reader.result }); + } + }; + reader.readAsText(file); + }, + [pushTextFile] + ); + useEffect(() => { const currentElem = inputRef?.current; const keypressListener = (e: DocumentEventMap['keypress']) => { @@ -211,11 +301,15 @@ const InputChatContent: React.FC = (props) => { for (let i = 0; i < fileList.length; i++) { const file = fileList.item(i); if (file) { - encodeAndPushImage(file); + if (file.type.startsWith('text/') || file.name.endsWith('.md') || file.name.endsWith('.ts')) { + handleFileRead(file); + } else { + encodeAndPushImage(file); + } } } }, - [encodeAndPushImage] + [encodeAndPushImage, handleFileRead] ); const onDragOver: React.DragEventHandler = useCallback( @@ -266,7 +360,7 @@ const InputChatContent: React.FC = (props) => { @@ -321,15 +415,47 @@ const InputChatContent: React.FC = (props) => {
)} - {messages.length > 1 && ( - + {textFiles.length > 0 && ( +
+ {textFiles.map((file, idx) => ( +
+ +
{truncateFileName(file.name)}
+ { + removeTextFile(idx); + }} + > + + +
+ ))} +
+ )} + {(getShouldContinue() || messages.length > 1) && ( +
+ {getShouldContinue() && ( + + )} + {messages.length > 1 && ( + + )} +
)} diff --git a/frontend/src/hooks/useChat.ts b/frontend/src/hooks/useChat.ts index 6811b59b5..446d9bb8d 100644 --- a/frontend/src/hooks/useChat.ts +++ b/frontend/src/hooks/useChat.ts @@ -81,6 +81,9 @@ const useChatState = create<{ setIsGeneratedTitle: (b: boolean) => void; getPostedModel: () => Model; shouldUpdateMessages: (currentConversation: Conversation) => boolean; + shouldCotinue: boolean; + setShouldContinue: (b: boolean) => void; + getShouldContinue: () => boolean; }>((set, get) => { return { conversationId: '', @@ -224,6 +227,15 @@ const useChatState = create<{ get().currentMessageId !== currentConversation.lastMessageId ); }, + getShouldContinue: () => { + return get().shouldCotinue; + }, + setShouldContinue: (b) => { + set(() => ({ + shouldCotinue: b, + })); + }, + shouldCotinue: false, }; }); @@ -252,6 +264,8 @@ const useChat = () => { setRelatedDocuments, moveRelatedDocuments, shouldUpdateMessages, + getShouldContinue, + setShouldContinue, } = useChatState(); const { open: openSnackbar } = useSnackbar(); const navigate = useNavigate(); @@ -300,6 +314,9 @@ const useChat = () => { moveRelatedDocuments(NEW_MESSAGE_ID.ASSISTANT, data.lastMessageId); } } + if (data && data.shouldContinue !== getShouldContinue()) { + setShouldContinue(data.shouldContinue); + } // eslint-disable-next-line react-hooks/exhaustive-deps }, [conversationId, data]); @@ -486,6 +503,61 @@ const useChat = () => { } }; + /** + * 生成を続ける + */ + const continueGenerate = (params?: { + messageId?: string; + bot?: BotInputType; + }) => { + setPostingMessage(true); + + const messageContent: MessageContent = { + content: [ + { + body: '', + contentType: 'text', + }, + ], + model: getPostedModel(), + role: 'user', + feedback: null, + usedChunks: null, + }; + const input: PostMessageRequest = { + conversationId: conversationId, + message: { + ...messageContent, + parentMessageId: messages[messages.length - 1].id, + }, + botId: params?.bot?.botId, + }; + + const currentContentBody = messages[messages.length - 1].content[0].body; + const currentMessage = messages[messages.length - 1]; + + postStreaming({ + input, + dispatch: (c: string) => { + editMessage(conversationId, currentMessage.id, currentContentBody + c); + }, + thinkingDispatch: (event) => { + send({ type: event }); + }, + }) + .then(() => { + mutate(); + }) + .catch((e) => { + console.error(e); + // setCurrentMessageId(NEW_MESSAGE_ID.USER); + // removeMessage(conversationId, NEW_MESSAGE_ID.ASSISTANT); + }) + .finally(() => { + setPostingMessage(false); + }); + }; + /** * 再生成 * @param props content: 内容を上書きしたい場合に設定 messageId: 再生成対象のmessageId botId: ボットの場合は設定する @@ -620,6 +692,8 @@ const useChat = () => { postChat, regenerate, getPostedModel, + getShouldContinue, + continueGenerate, // エラーのリトライ retryPostChat: (params: { content?: string; bot?: BotInputType }) => { const length_ = messages.length; diff --git a/frontend/src/i18n/de/index.ts b/frontend/src/i18n/de/index.ts index d7e2b71e4..7bcdd1422 100644 --- a/frontend/src/i18n/de/index.ts +++ b/frontend/src/i18n/de/index.ts @@ -172,6 +172,7 @@ Wie würden Sie diese E-Mail kategorisieren?`, signOut: 'Abmelden', close: 'Schließen', add: 'Hinzufügen', + continue: 'Weiter generieren', }, input: { hint: { diff --git a/frontend/src/i18n/en/index.ts b/frontend/src/i18n/en/index.ts index 8400d25ac..51242f4b9 100644 --- a/frontend/src/i18n/en/index.ts +++ b/frontend/src/i18n/en/index.ts @@ -397,6 +397,7 @@ How would you categorize this email?`, signOut: 'Sign out', close: 'Close', add: 'Add', + continue: 'Continue to generate', }, input: { hint: { diff --git a/frontend/src/i18n/es/index.ts b/frontend/src/i18n/es/index.ts index 18bfe048e..ac9d1b22a 100644 --- a/frontend/src/i18n/es/index.ts +++ b/frontend/src/i18n/es/index.ts @@ -309,6 +309,7 @@ Las categorías de clasificación son: signOut: 'Cerrar sesión', close: 'Cerrar', add: 'Agregar', + continue: 'Seguir generando', }, input: { hint: { diff --git a/frontend/src/i18n/fr/index.ts b/frontend/src/i18n/fr/index.ts index 302b6109b..b50537656 100644 --- a/frontend/src/i18n/fr/index.ts +++ b/frontend/src/i18n/fr/index.ts @@ -172,6 +172,7 @@ Comment catégoriseriez-vous cet e-mail ?`, signOut: 'Se déconnecter', close: 'Fermer', add: 'Ajouter', + continue: 'Continuer à générer', }, input: { hint: { diff --git a/frontend/src/i18n/it/index.ts b/frontend/src/i18n/it/index.ts index ed44eb8c4..227862e9d 100644 --- a/frontend/src/i18n/it/index.ts +++ b/frontend/src/i18n/it/index.ts @@ -332,6 +332,7 @@ Come classificheresti questa email?`, signOut: 'Disconnessione', close: 'Chiudi', add: 'Aggiungi', + continue: 'Continuare a generare', }, input: { hint: { diff --git a/frontend/src/i18n/ja/index.ts b/frontend/src/i18n/ja/index.ts index c55378d33..b7a1d5112 100644 --- a/frontend/src/i18n/ja/index.ts +++ b/frontend/src/i18n/ja/index.ts @@ -400,6 +400,7 @@ const translation = { signOut: 'サインアウト', close: '閉じる', add: '追加', + continue: '生成を続ける', }, input: { hint: { diff --git a/frontend/src/i18n/zh-hans/index.ts b/frontend/src/i18n/zh-hans/index.ts index 8b25e6d54..bc64a2a4e 100644 --- a/frontend/src/i18n/zh-hans/index.ts +++ b/frontend/src/i18n/zh-hans/index.ts @@ -329,6 +329,7 @@ signOut: '退出登录', close: '关闭', add: '新增', + continue: '继续生成', }, input: { hint: { diff --git a/frontend/src/i18n/zh-hant/index.ts b/frontend/src/i18n/zh-hant/index.ts index 251a4bf83..4ea47f362 100644 --- a/frontend/src/i18n/zh-hant/index.ts +++ b/frontend/src/i18n/zh-hant/index.ts @@ -329,6 +329,7 @@ signOut: '登出', close: '關閉', add: '新增', + continue: '繼續生成', }, input: { hint: { diff --git a/frontend/src/index.css b/frontend/src/index.css index ea235d31d..4324a8b12 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -8,3 +8,18 @@ --amplify-components-button-primary-background-color ) !important; } + +.file-name { + text-align: center; + max-width: 8rem; /* ファイル名の幅 */ + white-space: pre-wrap; /* 折り返しを有効にする */ + word-wrap: break-word; /* 長い単語の途中で折り返す */ +} + +.code-container { + padding: 3px; +} +.code-header { + margin-left: 18px; + margin-top: 5px; +} diff --git a/frontend/src/pages/ChatPage.tsx b/frontend/src/pages/ChatPage.tsx index c305059a0..9bc699427 100644 --- a/frontend/src/pages/ChatPage.tsx +++ b/frontend/src/pages/ChatPage.tsx @@ -49,6 +49,7 @@ const ChatPage: React.FC = () => { retryPostChat, setCurrentMessageId, regenerate, + continueGenerate, getPostedModel, loadingConversation, } = useChat(); @@ -160,6 +161,10 @@ const ChatPage: React.FC = () => { }); }, [inputBotParams, regenerate]); + const onContinueGenerate = useCallback(()=>{ + continueGenerate({bot: inputBotParams}); + }, [inputBotParams, continueGenerate]) + useEffect(() => { if (messages.length > 0) { scrollToBottom(); @@ -390,6 +395,7 @@ const ChatPage: React.FC = () => { } onSend={onSend} onRegenerate={onRegenerate} + continueGenerate={onContinueGenerate} /> )}