Skip to content

Commit

Permalink
✨ 允许知识库选择
Browse files Browse the repository at this point in the history
  • Loading branch information
twelvet-s committed Dec 16, 2024
1 parent ed95897 commit 5e33bc1
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;

/**
Expand All @@ -29,10 +26,15 @@ public class AIChatController {
@Autowired
private AIChatService aiChatService;

/**
* 回答用户问题
* @param messageDTO MessageDTO
* @return 流式输出回复
*/
@Operation(summary = "回答用户问题")
@PreAuthorize("@role.hasPermi('ai:chat')")
@PostMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<MessageVO> genAnswer(@RequestBody MessageDTO messageDTO) {
public Flux<MessageVO> chatStream(@RequestBody MessageDTO messageDTO) {
return aiChatService.chatStream(messageDTO);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,4 @@ public interface AIChatService {
*/
Flux<MessageVO> chatStream(MessageDTO messageDTO);

/**
* 格式化输出
* @param messageDTO
*/
Flux<MessageVO> formatTest(MessageDTO messageDTO);

}
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,21 @@ public Flux<MessageVO> chatStream(MessageDTO messageDTO) {
CompletableFuture<List<Message>> messagesCompletableFuture = CompletableFuture.supplyAsync(() -> {
// 加入历史对话
List<Message> messages = new ArrayList<>();
List<AiChatHistoryVO> aiChatHistoryList = aiChatHistoryService.selectAiChatHistoryListByUserId(userId,
aiModel.getMultiRound());
for (AiChatHistoryVO aiChatHistoryVO : aiChatHistoryList) {
RAGEnums.UserTypeEnums createByType = aiChatHistoryVO.getCreateByType();
String content = aiChatHistoryVO.getContent();
if (RAGEnums.UserTypeEnums.USER.equals(createByType)) {
messages.add(new UserMessage(content));
}
else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
messages.add(new AssistantMessage(content));
}
else {
throw new TWTException("无法匹配对应的会话用户类型");
if (Boolean.TRUE.equals(messageDTO.getCarryContextFlag())) {
List<AiChatHistoryVO> aiChatHistoryList = aiChatHistoryService.selectAiChatHistoryListByUserId(userId,
aiModel.getMultiRound());
for (AiChatHistoryVO aiChatHistoryVO : aiChatHistoryList) {
RAGEnums.UserTypeEnums createByType = aiChatHistoryVO.getCreateByType();
String content = aiChatHistoryVO.getContent();
if (RAGEnums.UserTypeEnums.USER.equals(createByType)) {
messages.add(new UserMessage(content));
}
else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
messages.add(new AssistantMessage(content));
}
else {
throw new TWTException("无法匹配对应的会话用户类型");
}
}
}
return messages;
Expand Down Expand Up @@ -199,6 +201,8 @@ else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
// 储存AI回答
// 回复时间必须保证在用户提问时间之前(重新获取时间,并且增加1毫秒),保证排序
LocalDateTime replyNow = LocalDateTime.now().plusNanos(1_000_000);
// 生成唯一消息雪花ID
String aiMsgId = String.valueOf(YitIdHelper.nextId());
// ai回复内容
StringBuffer aiContent = new StringBuffer();

Expand Down Expand Up @@ -227,6 +231,7 @@ else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
.map(chatResponse -> {
MessageVO messageVO = new MessageVO();
String content = chatResponse.getResult().getOutput().getContent();
messageVO.setMsgId(aiMsgId);
messageVO.setContent(content);
// 储存AI回复内容
aiContent.append(content);
Expand All @@ -236,8 +241,6 @@ else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
if (Arrays.asList(SignalType.CANCEL, SignalType.ON_COMPLETE).contains(signalType)) { // 取消链接时或完成输出时
// 储存AI提问
AiChatHistoryDTO aiChatHistoryDTO = new AiChatHistoryDTO();
// 生成唯一消息雪花ID
String aiMsgId = String.valueOf(YitIdHelper.nextId());
aiChatHistoryDTO.setMsgId(aiMsgId);
aiChatHistoryDTO.setUserId(userId);
aiChatHistoryDTO.setSendUserId(userId);
Expand All @@ -253,32 +256,4 @@ else if (RAGEnums.UserTypeEnums.AI.equals(createByType)) {
});
}

/**
* 格式化输出
* @param messageDTO
* @return
*/
@Override
public Flux<MessageVO> formatTest(MessageDTO messageDTO) {
BeanOutputConverter<List<ActorsFilms>> converter = new BeanOutputConverter<>(
new ParameterizedTypeReference<List<ActorsFilms>>() {
});

return ChatClient
// 自定义使用不同的大模型
.create(dashScopeChatModel)
.prompt()
.user(u -> u.text("""
Generate the filmography for a random {actor}.
{format}
""").param("actor", messageDTO.getContent()).param("format", converter.getFormat()))
.stream()
.chatResponse()
.map(chatResponse -> {
MessageVO messageVO = new MessageVO();
messageVO.setContent(chatResponse.getResult().getOutput().getContent());
return messageVO;
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public class MessageDTO implements Serializable {
@Schema(description = "提问内容")
private String content;

@Schema(description = "是否携带上下文记忆")
private Boolean carryContextFlag;

public @NotNull(message = "知识库ID不能为空") Long getModelId() {
return modelId;
}
Expand All @@ -45,9 +48,20 @@ public void setContent(@NotBlank(message = "提问内容不能为空") String co
this.content = content;
}

public Boolean getCarryContextFlag() {
return carryContextFlag;
}

public void setCarryContextFlag(Boolean carryContextFlag) {
this.carryContextFlag = carryContextFlag;
}

@Override
public String toString() {
return "MessageDTO{" + "modelId=" + modelId + ", content='" + content + '\'' + '}';
return "MessageDTO{" +
"modelId=" + modelId +
", content='" + content + '\'' +
", carryContextFlag=" + carryContextFlag +
'}';
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,26 @@
@Schema(description = "AI助手聊天VO")
public class MessageVO {

/**
* 消息唯一ID
*/
@Schema(description = "消息唯一ID")
private String msgId;

/**
* 问题内容
*/
@Schema(description = "问题内容")
private String content;

public String getMsgId() {
return msgId;
}

public void setMsgId(String msgId) {
this.msgId = msgId;
}

public String getContent() {
return content;
}
Expand All @@ -21,7 +35,9 @@ public void setContent(String content) {

@Override
public String toString() {
return "MessageVO{" + "content='" + content + '\'' + '}';
return "MessageVO{" +
"msgId='" + msgId + '\'' +
", content='" + content + '\'' +
'}';
}

}

0 comments on commit 5e33bc1

Please sign in to comment.