diff --git a/backend/app/repositories/conversation.py b/backend/app/repositories/conversation.py
index fb1170b67..b30775f02 100644
--- a/backend/app/repositories/conversation.py
+++ b/backend/app/repositories/conversation.py
@@ -48,6 +48,7 @@ def store_conversation(
         # Ref: https://stackoverflow.com/questions/63026648/errormessage-class-decimal-inexact-class-decimal-rounded-while
         "TotalPrice": decimal(str(conversation.total_price)),
         "LastMessageId": conversation.last_message_id,
+        "ShouldContinue": conversation.should_continue,
     }

     if conversation.bot_id:
@@ -236,6 +237,7 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM
         },
         last_message_id=item["LastMessageId"],
         bot_id=item["BotId"] if "BotId" in item else None,
+        should_continue=item.get("ShouldContinue", False),
     )
     logger.info(f"Found conversation: {conv}")
     return conv
diff --git a/backend/app/repositories/models/conversation.py b/backend/app/repositories/models/conversation.py
index 1118a25a4..12b7fb637 100644
--- a/backend/app/repositories/models/conversation.py
+++ b/backend/app/repositories/models/conversation.py
@@ -43,6 +43,7 @@ class ConversationModel(BaseModel):
     message_map: dict[str, MessageModel]
     last_message_id: str
     bot_id: str | None
+    should_continue: bool


 class ConversationMeta(BaseModel):
diff --git a/backend/app/routes/published_api.py b/backend/app/routes/published_api.py
index eaa77fa58..05f99e4de 100644
--- a/backend/app/routes/published_api.py
+++ b/backend/app/routes/published_api.py
@@ -56,6 +56,7 @@ def post_message(request: Request, message_input: ChatInputWithoutBotId):
             message_id=response_message_id,
         ),
         bot_id=bot_id,
+        continue_generate=message_input.continue_generate,
     )

     try:
diff --git a/backend/app/routes/schemas/conversation.py b/backend/app/routes/schemas/conversation.py
index 994f08d58..c5dc0dc58 100644
--- a/backend/app/routes/schemas/conversation.py
+++ b/backend/app/routes/schemas/conversation.py
@@ -82,6 +82,7 @@ class ChatInput(BaseSchema):
     conversation_id: str
     message: MessageInput
     bot_id: str | None = Field(None)
+    continue_generate: bool = Field(False)


 class ChatOutput(BaseSchema):
@@ -113,6 +114,7 @@ class Conversation(BaseSchema):
     message_map: dict[str, MessageOutput]
     last_message_id: str
     bot_id: str | None
+    should_continue: bool


 class NewTitleInput(BaseSchema):
diff --git a/backend/app/routes/schemas/published_api.py b/backend/app/routes/schemas/published_api.py
index 4d67166f6..e81ce1aa1 100644
--- a/backend/app/routes/schemas/published_api.py
+++ b/backend/app/routes/schemas/published_api.py
@@ -15,6 +15,7 @@ class ChatInputWithoutBotId(BaseSchema):
         If not provided, new conversation will be generated.""",
     )
     message: MessageInputWithoutMessageId
+    continue_generate: bool = Field(False)


 class ChatOutputWithoutBotId(BaseSchema):
diff --git a/backend/app/stream.py b/backend/app/stream.py
index c3a5d8e58..fd0b3dba2 100644
--- a/backend/app/stream.py
+++ b/backend/app/stream.py
@@ -95,7 +95,7 @@ def run(self, args: dict):
             )
             response = self.on_stop(
                 OnStopInput(
-                    full_token=concatenated,
+                    full_token=concatenated.rstrip(),
                     stop_reason=stop_reason,
                     input_token_count=input_token_count,
                     output_token_count=output_token_count,
@@ -134,7 +134,7 @@ def run(self, args: dict):
             )
             res = self.on_stop(
                 OnStopInput(
-                    full_token=concatenated,
+                    full_token=concatenated.rstrip(),
                     stop_reason=stop_reason,
                     input_token_count=input_token_count,
                     output_token_count=output_token_count,
diff --git a/backend/app/usecases/chat.py b/backend/app/usecases/chat.py
index e6bd5ee42..f25a3b596 100644
--- a/backend/app/usecases/chat.py
+++ b/backend/app/usecases/chat.py
@@ -180,6 +180,7 @@ def prepare_conversation(
             message_map=initial_message_map,
             last_message_id="",
             bot_id=chat_input.bot_id,
+            should_continue=False,
         )

     # Append user chat input to the conversation
@@ -187,26 +188,28 @@
         message_id = chat_input.message.message_id
     else:
         message_id = str(ULID())
-    new_message = MessageModel(
-        role=chat_input.message.role,
-        content=[
-            ContentModel(
-                content_type=c.content_type,
-                media_type=c.media_type,
-                body=c.body,
-            )
-            for c in chat_input.message.content
-        ],
-        model=chat_input.message.model,
-        children=[],
-        parent=parent_id,
-        create_time=current_time,
-        feedback=None,
-        used_chunks=None,
-        thinking_log=None,
-    )
-    conversation.message_map[message_id] = new_message
-    conversation.message_map[parent_id].children.append(message_id)  # type: ignore
+    # When the "Continue generating" button is pressed, no new user message is created.
+    if not chat_input.continue_generate:
+        new_message = MessageModel(
+            role=chat_input.message.role,
+            content=[
+                ContentModel(
+                    content_type=c.content_type,
+                    media_type=c.media_type,
+                    body=c.body,
+                )
+                for c in chat_input.message.content
+            ],
+            model=chat_input.message.model,
+            children=[],
+            parent=parent_id,
+            create_time=current_time,
+            feedback=None,
+            used_chunks=None,
+            thinking_log=None,
+        )
+        conversation.message_map[message_id] = new_message
+        conversation.message_map[parent_id].children.append(message_id)  # type: ignore

     return (message_id, conversation, bot)

@@ -335,7 +338,9 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
     messages = trace_to_root(
         node_id=chat_input.message.parent_message_id, message_map=message_map
     )
-    messages.append(chat_input.message)  # type: ignore
+
+    if not chat_input.continue_generate:
+        messages.append(chat_input.message)  # type: ignore

     # Create payload to invoke Bedrock
     args = compose_args(
@@ -357,6 +362,8 @@
         response = get_bedrock_response(args)  # type: ignore
         reply_txt = response["outputs"][0]["text"]  # type: ignore

+        reply_txt = reply_txt.rstrip()
+
     # Used chunks for RAG generation
     if bot and bot.display_retrieved_chunks and is_running_on_lambda():
         if len(search_results) > 0:
@@ -395,14 +402,23 @@
         used_chunks=used_chunks,
         thinking_log=thinking_log,
     )
-    conversation.message_map[assistant_msg_id] = message
-    # Append children to parent
-    conversation.message_map[user_msg_id].children.append(assistant_msg_id)
-    conversation.last_message_id = assistant_msg_id
+    if chat_input.continue_generate:
+        conversation.message_map[conversation.last_message_id].content[
+            0
+        ].body += reply_txt
+    else:
+        conversation.message_map[assistant_msg_id] = message
+
+        # Append children to parent
+        conversation.message_map[user_msg_id].children.append(assistant_msg_id)
+        conversation.last_message_id = assistant_msg_id

     conversation.total_price += price

+    # Record whether generation stopped at the token limit so the client can offer to continue
+    conversation.should_continue = response["outputs"][0]["stop_reason"] == "max_tokens"  # type: ignore
+
     # Store updated conversation
     store_conversation(user_id, conversation)

     # Update bot last used time
@@ -570,6 +586,7 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation:
         last_message_id=conversation.last_message_id,
         message_map=message_map,
         bot_id=conversation.bot_id,
+        should_continue=conversation.should_continue,
     )
     return output
diff --git a/backend/app/websocket.py b/backend/app/websocket.py
index ba11b57d7..8a21dd9aa 100644
--- a/backend/app/websocket.py
+++ b/backend/app/websocket.py
@@ -181,7 +181,9 @@ def process_chat_input(
         node_id=chat_input.message.parent_message_id,
         message_map=message_map,
     )
-    messages.append(chat_input.message)  # type: ignore
+
+    if not chat_input.continue_generate:
+        messages.append(chat_input.message)  # type: ignore

     args = compose_args(
         messages,
@@ -203,43 +205,54 @@ def on_stream(token: str, **kwargs) -> None:
         gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send)

     def on_stop(arg: OnStopInput, **kwargs) -> None:
-        used_chunks = None
-        if bot and bot.display_retrieved_chunks:
-            if len(search_results) > 0:
-                used_chunks = []
-                for r in filter_used_results(arg.full_token, search_results):
-                    content_type, source_link = get_source_link(r.source)
-                    used_chunks.append(
-                        ChunkModel(
-                            content=r.content,
-                            content_type=content_type,
-                            source=source_link,
-                            rank=r.rank,
-                        )
-                    )
-        # Append entire completion as the last message
-        assistant_msg_id = str(ULID())
-        message = MessageModel(
-            role="assistant",
-            content=[
-                ContentModel(content_type="text", body=arg.full_token, media_type=None)
-            ],
-            model=chat_input.message.model,
-            children=[],
-            parent=user_msg_id,
-            create_time=get_current_time(),
-            feedback=None,
-            used_chunks=used_chunks,
-            thinking_log=None,
-        )
-        conversation.message_map[assistant_msg_id] = message
-        # Append children to parent
-        conversation.message_map[user_msg_id].children.append(assistant_msg_id)
-        conversation.last_message_id = assistant_msg_id
+        if chat_input.continue_generate:
+            # Continuing generation: append the new tokens to the last assistant message
+            conversation.message_map[conversation.last_message_id].content[
+                0
+            ].body += arg.full_token
+        else:
+            used_chunks = None
+            if bot and bot.display_retrieved_chunks:
+                if len(search_results) > 0:
+                    used_chunks = []
+                    for r in filter_used_results(arg.full_token, search_results):
+                        content_type, source_link = get_source_link(r.source)
+                        used_chunks.append(
+                            ChunkModel(
+                                content=r.content,
+                                content_type=content_type,
+                                source=source_link,
+                                rank=r.rank,
+                            )
+                        )
+            # Append entire completion as the last message
+            assistant_msg_id = str(ULID())
+            message = MessageModel(
+                role="assistant",
+                content=[
+                    ContentModel(
+                        content_type="text", body=arg.full_token, media_type=None
+                    )
+                ],
+                model=chat_input.message.model,
+                children=[],
+                parent=user_msg_id,
+                create_time=get_current_time(),
+                feedback=None,
+                used_chunks=used_chunks,
+                thinking_log=None,
+            )
+            conversation.message_map[assistant_msg_id] = message
+            # Append children to parent
+            conversation.message_map[user_msg_id].children.append(assistant_msg_id)
+            conversation.last_message_id = assistant_msg_id

         conversation.total_price += arg.price

+        # Record whether generation stopped at the token limit so the client can offer to continue
+        conversation.should_continue = arg.stop_reason == "max_tokens"
+
         # Store conversation before finish streaming so that front-end can avoid 404 issue
         store_conversation(user_id, conversation)
         last_data_to_send = json.dumps(
diff --git a/backend/tests/test_repositories/test_conversation.py b/backend/tests/test_repositories/test_conversation.py
index 395f859d1..857016b67 100644
--- a/backend/tests/test_repositories/test_conversation.py
+++ b/backend/tests/test_repositories/test_conversation.py
@@ -149,6 +149,7 @@ def test_store_and_find_conversation(self):
             },
             last_message_id="x",
             bot_id=None,
+            should_continue=False,
         )

         # Test storing conversation
@@ -186,6 +187,7 @@ def test_store_and_find_conversation(self):
         self.assertEqual(found_conversation.last_message_id, "x")
         self.assertEqual(found_conversation.total_price, 100)
         self.assertEqual(found_conversation.bot_id, None)
+        self.assertEqual(found_conversation.should_continue, False)

         # Test update title
         response = change_conversation_title(
@@ -260,6 +262,7 @@ def test_store_and_find_large_conversation(self):
             message_map=large_message_map,
             last_message_id="msg_9",
             bot_id=None,
+            should_continue=False,
         )

         # Test storing large conversation with a small threshold
@@ -275,6 +278,7 @@ def test_store_and_find_large_conversation(self):
         self.assertEqual(found_conversation.total_price, 200)
         self.assertEqual(found_conversation.last_message_id, "msg_9")
         self.assertEqual(found_conversation.bot_id, None)
+        self.assertEqual(found_conversation.should_continue, False)

         message_map = found_conversation.message_map
         self.assertEqual(len(message_map), 10)
@@ -335,6 +339,7 @@ def setUp(self) -> None:
             },
             last_message_id="x",
             bot_id=None,
+            should_continue=False,
         )
         conversation2 = ConversationModel(
             id="2",
@@ -365,6 +370,7 @@ def setUp(self) -> None:
             },
             last_message_id="x",
             bot_id="1",
+            should_continue=False,
         )
         bot1 = BotModel(
             id="1",
diff --git a/backend/tests/test_usecases/test_chat.py b/backend/tests/test_usecases/test_chat.py
index 03c41f4f6..f5f39e689 100644
--- a/backend/tests/test_usecases/test_chat.py
+++ b/backend/tests/test_usecases/test_chat.py
@@ -325,6 +325,7 @@ def setUp(self) -> None:
                 ),
             },
             bot_id=None,
+            should_continue=False,
         ),
     )

@@ -345,6 +346,7 @@ def test_continue_chat(self):
                 message_id=None,
             ),
             bot_id=None,
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
         self.output = output
@@ -449,6 +451,7 @@ def setUp(self) -> None:
                 ),
             },
             bot_id=None,
+            should_continue=False,
         ),
     )

@@ -471,6 +474,7 @@ def test_chat(self):
                 message_id=None,
             ),
             bot_id=None,
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
         self.output = output
@@ -497,6 +501,7 @@ def test_chat(self):
                 message_id=None,
             ),
             bot_id=None,
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id=self.user_id, chat_input=chat_input)
         self.output = output
@@ -528,6 +533,7 @@ def setUp(self) -> None:
                 message_id=None,
             ),
             bot_id=None,
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
         print(output)
@@ -599,6 +605,7 @@ def test_chat_with_private_bot(self):
                 message_id=None,
             ),
             bot_id="private1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
         print(output)
@@ -625,6 +632,7 @@ def test_chat_with_private_bot(self):
                 message_id=None,
             ),
             bot_id="private1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
         print(output)
@@ -646,6 +654,7 @@ def test_chat_with_private_bot(self):
                 message_id=None,
             ),
             bot_id="private1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

@@ -671,6 +680,7 @@ def test_chat_with_public_bot(self):
                 message_id=None,
             ),
             bot_id="public1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

@@ -693,6 +703,7 @@ def test_chat_with_public_bot(self):
                 message_id=None,
             ),
             bot_id="private1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
         print(output)
@@ -717,6 +728,7 @@ def test_fetch_conversation(self):
                 message_id=None,
             ),
             bot_id="private1",
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id="user1", chat_input=chat_input)

@@ -767,6 +779,7 @@ def test_agent_chat(self):
                 message_id=None,
             ),
             bot_id=self.bot_id,
+            continue_generate=False,
         )
         output: ChatOutput = chat(user_id=self.user_name, chat_input=chat_input)
         print(output.message.content[0].body)
@@ -846,6 +859,7 @@ def test_insert_knowledge(self):
             },
             bot_id="bot1",
             last_message_id="1-user",
+            should_continue=False,
         )
         conversation_with_context = insert_knowledge(
             conversation, results, display_citation=True
@@ -872,6 +886,7 @@ def test_streaming_api(self):
                 message_id=None,
             ),
             bot_id=None,
+            continue_generate=False,
         )
         user_msg_id, conversation, bot = prepare_conversation("user1", chat_input)
         messages = trace_to_root(
diff --git a/frontend/src/@types/conversation.d.ts b/frontend/src/@types/conversation.d.ts
index bb26c704b..5d75aaee7 100644
--- a/frontend/src/@types/conversation.d.ts
+++ b/frontend/src/@types/conversation.d.ts
@@ -50,6 +50,7 @@ export type PostMessageRequest = {
     parentMessageId: null | string;
   };
   botId?: string;
+  continueGenerate?: boolean;
};

export type PostMessageResponse = {
@@ -86,6 +87,7 @@ export type MessageMap = {

export type Conversation = ConversationMeta & {
   messageMap: MessageMap;
+  shouldContinue: boolean;
};

export type PutFeedbackRequest = {
diff --git a/frontend/src/components/InputChatContent.tsx b/frontend/src/components/InputChatContent.tsx
index 7ce55f405..e255dc421 100644
--- a/frontend/src/components/InputChatContent.tsx
+++ b/frontend/src/components/InputChatContent.tsx
@@ -9,7 +9,7 @@ import ButtonSend from './ButtonSend';
import Textarea from './Textarea';
import useChat from '../hooks/useChat';
import Button from './Button';
-import { PiArrowsCounterClockwise, PiX } from 'react-icons/pi';
+import { PiArrowsCounterClockwise, PiX, PiArrowFatLineRight } from 'react-icons/pi';
import { TbPhotoPlus } from 'react-icons/tb';
import { useTranslation } from 'react-i18next';
import ButtonIcon from './ButtonIcon';
@@ -28,6 +28,7 @@ type Props = BaseProps & {
   dndMode?: boolean;
   onSend: (content: string, base64EncodedImages?: string[]) => void;
   onRegenerate: () => void;
+  continueGenerate: () => void;
};

const MAX_IMAGE_WIDTH = 800;
@@ -75,9 +76,10 @@ const useInputChatContentState = create<{

const InputChatContent: React.FC<Props> = (props) => {
   const { t } = useTranslation();
-  const { postingMessage, hasError, messages } = useChat();
+  const { postingMessage, hasError, messages, getShouldContinue } = useChat();
   const { disabledImageUpload, model, acceptMediaType } = useModel();
-
+  const [shouldContinue, setShouldContinue] = useState(false);
+
   const [content, setContent] = useState('');
   const {
     base64EncodedImages,
@@ -95,6 +97,14 @@
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, []);

+  useEffect(() => {
+    const checkShouldContinue = async () => {
+      const result = await getShouldContinue();
+      setShouldContinue(result);
+    };
+    checkShouldContinue();
+  }, [getShouldContinue, postingMessage, content, props, hasError]);
+
   const disabledSend = useMemo(() => {
     return content === '' || props.disabledSend || hasError;
   }, [hasError, content, props.disabledSend]);
@@ -102,6 +112,10 @@
   const disabledRegenerate = useMemo(() => {
     return postingMessage || hasError;
   }, [hasError, postingMessage]);
+
+  const disableContinue = useMemo(() => {
+    return postingMessage || hasError;
+  }, [hasError, postingMessage]);

   const inputRef = useRef(null);
@@ -322,14 +336,25 @@
       )}
       {messages.length > 1 && (
-        <Button
-          className="absolute -top-14 right-0 bg-aws-paper p-2 text-sm"
-          outlined
-          disabled={disabledRegenerate || props.disabled}
-          onClick={props.onRegenerate}>
-          <PiArrowsCounterClockwise className="mr-2" />
-          {t('button.regenerate')}
-        </Button>
+        <div className="absolute -top-14 right-0 flex space-x-2">
+          {shouldContinue && !disableContinue && !props.disabled && (
+            <Button
+              className="bg-aws-paper p-2 text-sm"
+              outlined
+              onClick={props.continueGenerate}>
+              <PiArrowFatLineRight className="mr-2" />
+              {t('button.continue')}
+            </Button>
+          )}
+          <Button
+            className="bg-aws-paper p-2 text-sm"
+            outlined
+            disabled={disabledRegenerate || props.disabled}
+            onClick={props.onRegenerate}>
+            <PiArrowsCounterClockwise className="mr-2" />
+            {t('button.regenerate')}
+          </Button>
+        </div>
       )}
diff --git a/frontend/src/hooks/useChat.ts b/frontend/src/hooks/useChat.ts
index 6811b59b5..48861a5bb 100644
--- a/frontend/src/hooks/useChat.ts
+++ b/frontend/src/hooks/useChat.ts
@@ -81,6 +81,9 @@ const useChatState = create<{
   setIsGeneratedTitle: (b: boolean) => void;
   getPostedModel: () => Model;
   shouldUpdateMessages: (currentConversation: Conversation) => boolean;
+  shouldContinue: boolean;
+  setShouldContinue: (b: boolean) => void;
+  getShouldContinue: () => boolean;
}>((set, get) => {
   return {
     conversationId: '',
@@ -224,6 +227,15 @@
         get().currentMessageId !== currentConversation.lastMessageId
       );
     },
+    getShouldContinue: () => {
+      return get().shouldContinue;
+    },
+    setShouldContinue: (b) => {
+      set(() => ({
+        shouldContinue: b,
+      }));
+    },
+    shouldContinue: false,
   };
});

@@ -252,6 +264,8 @@ const useChat = () => {
     setRelatedDocuments,
     moveRelatedDocuments,
     shouldUpdateMessages,
+    getShouldContinue,
+    setShouldContinue,
   } = useChatState();
   const { open: openSnackbar } = useSnackbar();
   const navigate = useNavigate();
@@ -300,6 +314,9 @@
         moveRelatedDocuments(NEW_MESSAGE_ID.ASSISTANT, data.lastMessageId);
       }
     }
+    if (data && data.shouldContinue !== getShouldContinue()) {
+      setShouldContinue(data.shouldContinue);
+    }
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [conversationId, data]);

@@ -486,6 +503,56 @@
     }
   };

+  /**
+   * Continue generating
+   */
+  const continueGenerate = (params?: {
+    messageId?: string;
+    bot?: BotInputType;
+  }) => {
+    setPostingMessage(true);
+
+    const messageContent: MessageContent = {
+      content: [],
+      model: getPostedModel(),
+      role: 'user',
+      feedback: null,
+      usedChunks: null,
+    };
+    const input: PostMessageRequest = {
+      conversationId: conversationId,
+      message: {
+        ...messageContent,
+        parentMessageId: messages[messages.length - 1].id,
+      },
+      botId: params?.bot?.botId,
+      continueGenerate: true,
+    };
+
+    const currentContentBody = messages[messages.length - 1].content[0].body;
+    const currentMessage = messages[messages.length - 1];
+
+    // WARNING: Non-streaming is not supported from the UI side, as it is planned to be DEPRECATED.
+    postStreaming({
+      input,
+      dispatch: (c: string) => {
+        editMessage(conversationId, currentMessage.id, currentContentBody + c);
+      },
+      thinkingDispatch: (event) => {
+        send({ type: event });
+      },
+    })
+      .then(() => {
+        mutate();
+      })
+      .catch((e) => {
+        console.error(e);
+      })
+      .finally(() => {
+        setPostingMessage(false);
+      });
+  };
+
   /**
    * Regenerate
    * @param props content: set to overwrite the content; messageId: the messageId to regenerate; botId: set when chatting with a bot
    */
@@ -620,6 +687,8 @@
     postChat,
     regenerate,
     getPostedModel,
+    getShouldContinue,
+    continueGenerate,
     // Retry after an error
     retryPostChat: (params: { content?: string; bot?: BotInputType }) => {
       const length_ = messages.length;
diff --git a/frontend/src/i18n/de/index.ts b/frontend/src/i18n/de/index.ts
index d7e2b71e4..7bcdd1422 100644
--- a/frontend/src/i18n/de/index.ts
+++ b/frontend/src/i18n/de/index.ts
@@ -172,6 +172,7 @@ Wie würden Sie diese E-Mail kategorisieren?`,
     signOut: 'Abmelden',
     close: 'Schließen',
     add: 'Hinzufügen',
+    continue: 'Weiter generieren',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/en/index.ts b/frontend/src/i18n/en/index.ts
index 71f1609cd..7a658b3fa 100644
--- a/frontend/src/i18n/en/index.ts
+++ b/frontend/src/i18n/en/index.ts
@@ -405,6 +405,7 @@ How would you categorize this email?`,
     signOut: 'Sign out',
     close: 'Close',
     add: 'Add',
+    continue: 'Continue to generate',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/es/index.ts b/frontend/src/i18n/es/index.ts
index 18bfe048e..ac9d1b22a 100644
--- a/frontend/src/i18n/es/index.ts
+++ b/frontend/src/i18n/es/index.ts
@@ -309,6 +309,7 @@ Las categorías de clasificación son:
     signOut: 'Cerrar sesión',
     close: 'Cerrar',
     add: 'Agregar',
+    continue: 'Seguir generando',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/fr/index.ts b/frontend/src/i18n/fr/index.ts
index 302b6109b..b50537656 100644
--- a/frontend/src/i18n/fr/index.ts
+++ b/frontend/src/i18n/fr/index.ts
@@ -172,6 +172,7 @@ Comment catégoriseriez-vous cet e-mail ?`,
     signOut: 'Se déconnecter',
     close: 'Fermer',
     add: 'Ajouter',
+    continue: 'Continuer à générer',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/it/index.ts b/frontend/src/i18n/it/index.ts
index ed44eb8c4..227862e9d 100644
--- a/frontend/src/i18n/it/index.ts
+++ b/frontend/src/i18n/it/index.ts
@@ -332,6 +332,7 @@ Come classificheresti questa email?`,
     signOut: 'Disconnessione',
     close: 'Chiudi',
     add: 'Aggiungi',
+    continue: 'Continuare a generare',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/ja/index.ts b/frontend/src/i18n/ja/index.ts
index 798f875f4..0269f31f7 100644
--- a/frontend/src/i18n/ja/index.ts
+++ b/frontend/src/i18n/ja/index.ts
@@ -409,6 +409,7 @@ const translation = {
     signOut: 'サインアウト',
     close: '閉じる',
     add: '追加',
+    continue: '生成を続ける',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/zh-hans/index.ts b/frontend/src/i18n/zh-hans/index.ts
index 8b25e6d54..bc64a2a4e 100644
--- a/frontend/src/i18n/zh-hans/index.ts
+++ b/frontend/src/i18n/zh-hans/index.ts
@@ -329,6 +329,7 @@
     signOut: '退出登录',
     close: '关闭',
     add: '新增',
+    continue: '继续生成',
   },
   input: {
     hint: {
diff --git a/frontend/src/i18n/zh-hant/index.ts b/frontend/src/i18n/zh-hant/index.ts
index 251a4bf83..4ea47f362 100644
--- a/frontend/src/i18n/zh-hant/index.ts
+++ b/frontend/src/i18n/zh-hant/index.ts
@@ -329,6 +329,7 @@
     signOut: '登出',
     close: '關閉',
     add: '新增',
+    continue: '繼續生成',
   },
   input: {
     hint: {
diff --git a/frontend/src/pages/ChatPage.tsx b/frontend/src/pages/ChatPage.tsx
index 839c20808..8c4d404d6 100644
--- a/frontend/src/pages/ChatPage.tsx
+++ b/frontend/src/pages/ChatPage.tsx
@@ -56,6 +56,7 @@ const ChatPage: React.FC = () => {
     retryPostChat,
     setCurrentMessageId,
     regenerate,
+    continueGenerate,
     getPostedModel,
     loadingConversation,
   } = useChat();
@@ -167,6 +168,10 @@ const ChatPage: React.FC = () => {
     });
   }, [inputBotParams, regenerate]);

+  const onContinueGenerate = useCallback(() => {
+    continueGenerate({ bot: inputBotParams });
+  }, [inputBotParams, continueGenerate]);
+
   useLayoutEffect(() => {
     if (messages.length > 0) {
       scrollToBottom();
@@ -431,6 +436,7 @@ const ChatPage: React.FC = () => {
             }
             onSend={onSend}
             onRegenerate={onRegenerate}
+            continueGenerate={onContinueGenerate}
           />
         )}