Skip to content

Commit

Permalink
[fix] 챗봇 로직 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
hysong4u committed Oct 4, 2024
1 parent 460a8ac commit 65d2389
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 224 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
package com.example.hackdive.domain.message.contoller;

import com.example.hackdive.domain.message.dto.MessageInput;
import com.example.hackdive.domain.message.entity.Message;
import com.example.hackdive.domain.message.service.MessageService;
import com.example.hackdive.global.common.SuccessResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.AllArgsConstructor;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;

@RestController
@RequestMapping("/api/message")
Expand All @@ -22,43 +20,30 @@ public class MessageController {
private final MessageService messageService;

// AI 채팅 반환(stream)
@GetMapping("/stream/{workspaceId}/{isFirst}")
public SseEmitter streamMessages(@PathVariable Long workspaceId, @PathVariable boolean isFirst) {
SseEmitter emitter = new SseEmitter(60000L);
@GetMapping("/recieve/{workspaceId}/{isFirst}")
public SseEmitter streamMessages(@PathVariable("workspaceId") Long workspaceId, @PathVariable("isFirst") boolean isFirst) {
SseEmitter emitter = new SseEmitter();

CompletableFuture.runAsync(() -> {
try {
messageService.streamMessages(workspaceId, isFirst)
.subscribe(
content -> {
try {
messageService.addEvent(emitter, content);
} catch (IOException e) {
throw new RuntimeException(e);
}
},
emitter::completeWithError,
emitter::complete
);
} catch (Exception e) {
emitter.completeWithError(e);
}
});
Flux<String> str = messageService.streamMessages(workspaceId, isFirst);

str.subscribe(
data -> {
try {
emitter.send(data);
} catch (IOException e) {
emitter.completeWithError(e);
}
},
emitter::completeWithError,
emitter::complete
);
return emitter;
}


// AI 채팅 반환
@GetMapping("/sync/{workspaceId}/{isFirst}")
public String getGptOutputSync(@PathVariable("workspaceId") Long workspaceId, @PathVariable("isFirst") boolean isFirst) {
return messageService.getGptOutputSync(workspaceId, isFirst);
}

// 유저 채팅 전송
@PostMapping("/send")
public ResponseEntity<SuccessResponse<?>> sendMessage (@RequestBody MessageInput message) {
messageService.saveMessage(message);
@PostMapping("/send/{workspaceId}")
public ResponseEntity<SuccessResponse<?>> sendMessage(@PathVariable Long workspaceId, @RequestBody String message) throws JsonProcessingException {
messageService.saveMessage(workspaceId, message);
return SuccessResponse.ok("");
}

Expand All @@ -69,4 +54,13 @@ public ResponseEntity<SuccessResponse<?>> getChatPage(@PathVariable("workspaceId
return SuccessResponse.ok(messages);
}


// AI 채팅 반환(sync)
/*
@GetMapping("/sync/{workspaceId}/{isFirst}")
public String getGptOutputSync(@PathVariable("workspaceId") Long workspaceId, @PathVariable("isFirst") boolean isFirst) {
return messageService.getGptOutputSync(workspaceId, isFirst);
}
*/

}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package com.example.hackdive.domain.message.service;

import com.example.hackdive.domain.message.dto.GptRequestDTO;
import com.example.hackdive.domain.message.dto.MessageInput;
import com.example.hackdive.domain.message.entity.Message;
import com.example.hackdive.domain.message.repository.MessageRepository;
import com.example.hackdive.domain.workspace.entity.Workspace;
import com.example.hackdive.domain.workspace.repository.WorkspaceRepository;
import com.example.hackdive.global.cofig.GPTConfig;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Value;
Expand Down Expand Up @@ -36,14 +36,18 @@ public MessageService(MessageRepository messageRepository, WorkspaceRepository w
this.workspaceRepository = workspaceRepository;
}

public void saveMessage(MessageInput message) {
Workspace workspace = workspaceRepository.findById(message.getWorkspaceId())
.orElseThrow(() -> new RuntimeException("No workspace id " + message.getWorkspaceId()));
public void saveMessage(Long workspaceId, String message) throws JsonProcessingException {
Workspace workspace = workspaceRepository.findById(workspaceId)
.orElseThrow(() -> new RuntimeException("No workspace id " + workspaceId));

ObjectMapper objectMapper = new ObjectMapper();
JsonNode jsonNode = objectMapper.readTree(message);
String content = jsonNode.get("message").asText();

Message newMessage = Message.builder()
.workspace(workspace)
.role(message.getRole())
.content(message.getContent())
.role(GPTConfig.ROLE_USER)
.content(content)
.createdAt(LocalDateTime.now())
.build();

Expand All @@ -57,44 +61,23 @@ public List<Message> getAllMessage(Long workspaceId) {
return messageRepository.findAllByWorkspaceOrderByCreatedAtDesc(workspace);
}


public List<Message> getLLMInputs(Workspace workspace, boolean isFirst) {
List<Message> messages = messageRepository.findAllByWorkspaceOrderByCreatedAtDesc(workspace);
if (messages == null) {
throw new RuntimeException("The Messages is Null");
}

List<Message> parsedDatas = new ArrayList<>();
Message systemMessage = Message.builder()
.createdAt(LocalDateTime.now())
.content(GPTConfig.getSystemPrompts(isFirst))
.role("assistant")
.workspace(workspace)
.build();
parsedDatas.add(systemMessage);

if (!isFirst) {
for (int i = messages.size() - 1; i >= 0; i--) {
parsedDatas.add(messages.get(i));
}
}

return parsedDatas;
}

public String getGptOutputSync(Long workspaceId, boolean isFirst) {
Workspace workspace = workspaceRepository.findById(workspaceId)
.orElseThrow(() -> new RuntimeException("No workspace id " + workspaceId));

List<Message> inputMessages = getLLMInputs(workspace, isFirst);
String gptResponse = getResponseSync(inputMessages);

saveMessage(MessageInput.builder()
.content(gptResponse)
.role(GPTConfig.ROLE_ASSISTANT)
.workspaceId(workspaceId)
.build());

return gptResponse;
}

public Flux<String> getResponse(List<Message> messages) {
WebClient webClient = WebClient.builder()
Expand All @@ -119,19 +102,62 @@ public Flux<String> getResponse(List<Message> messages) {
}

public String extractContent(String jsonEvent) {
if ("DONE".equals(jsonEvent)) {
return "";
}
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(jsonEvent);
return node.at("/choices/0/message/content").asText();
return node.at("/choices/0/delta/content").asText();
} catch (IOException e) {
System.err.println("Error processing JSON: " + e.getMessage());
return "";
}
}

public Flux<String> streamMessages(Long workspaceId, boolean isFirst) {
Workspace workspace = workspaceRepository.findById(workspaceId)
.orElseThrow(() -> new RuntimeException("No workspace id " + workspaceId));

List<Message> inputMessages = getLLMInputs(workspace, isFirst);
StringBuilder accumulatedContent = new StringBuilder();

return Flux.create(sink -> {
if (isFirst) {
String systemPrompt = GPTConfig.getSystemPrompts(true);
accumulatedContent.append(systemPrompt);

messageRepository.save(Message.builder()
.workspace(workspace)
.role(GPTConfig.ROLE_ASSISTANT)
.content(systemPrompt)
.createdAt(LocalDateTime.now())
.build());
sink.next(systemPrompt);
return;
}

Flux<String> eventStream = getResponse(inputMessages);

eventStream.subscribe(
content -> {
String extractedContent = extractContent(content);
accumulatedContent.append(extractedContent);
sink.next(extractedContent);
},
sink::error,
() -> {
messageRepository.save(Message.builder()
.workspace(workspace)
.role(GPTConfig.ROLE_ASSISTANT)
.content(accumulatedContent.toString())
.createdAt(LocalDateTime.now())
.build());
sink.complete();
});

new SseEmitter().onTimeout(sink::complete);
});
}

/* 동기(sync) 호출
public String getResponseSync(List<Message> messages) {
WebClient webClient = WebClient.builder()
.baseUrl(GPTConfig.CHAT_URL)
Expand All @@ -158,33 +184,20 @@ public String getResponseSync(List<Message> messages) {
throw new RuntimeException("Error getting GPT response: " + e.getMessage());
}
}

public void addEvent(SseEmitter emitter, String content) throws IOException {
SseEmitter.SseEventBuilder eventBuilder = SseEmitter.event().data(content).name("message");
emitter.send(eventBuilder);
}

public Flux<String> streamMessages(Long workspaceId, boolean isFirst) {
public String getGptOutputSync(Long workspaceId, boolean isFirst) {
Workspace workspace = workspaceRepository.findById(workspaceId)
.orElseThrow(() -> new RuntimeException("No workspace id " + workspaceId));
List<Message> inputMessages = getLLMInputs(workspace, isFirst);
String gptResponse = getResponseSync(inputMessages);
return Flux.create(sink -> {
Flux<String> eventStream = getResponse(inputMessages);
eventStream.subscribe(
content -> {
try {
addEvent(new SseEmitter(), content);
sink.next(content);
} catch (IOException e) {
sink.error(e);
}
},
sink::error,
sink::complete);
saveMessage(MessageInput.builder()
.content(gptResponse)
.role(GPTConfig.ROLE_ASSISTANT)
.workspaceId(workspaceId)
.build());
new SseEmitter().onTimeout(sink::complete);
});
return gptResponse;
}
*/
}
Loading

0 comments on commit 65d2389

Please sign in to comment.