From 19bb3d4f8f461bca1f0f40f23bf19222bfc81d07 Mon Sep 17 00:00:00 2001 From: Zihao Zhou Date: Thu, 22 Aug 2024 02:05:46 +0000 Subject: [PATCH] Use unifying string and vision converter --- pkg/adapter/struct.go | 79 ++++++++----------------------------------- 1 file changed, 15 insertions(+), 64 deletions(-) diff --git a/pkg/adapter/struct.go b/pkg/adapter/struct.go index db79de9..5082c6c 100644 --- a/pkg/adapter/struct.go +++ b/pkg/adapter/struct.go @@ -13,22 +13,6 @@ type ChatCompletionMessage struct { Content json.RawMessage `json:"content"` } -func (m *ChatCompletionMessage) stringContent() (str string, err error) { - err = json.Unmarshal(m.Content, &str) - if err != nil { - return "", errors.Wrap(err, "json.Unmarshal") - } - return -} - -func (m *ChatCompletionMessage) multiContent() (parts []openai.ChatMessagePart, err error) { - err = json.Unmarshal(m.Content, &parts) - if err != nil { - return nil, errors.Wrap(err, "json.Unmarshal") - } - return -} - // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model" binding:"required"` @@ -42,21 +26,29 @@ type ChatCompletionRequest struct { } func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) { - if req.Model == Gemini1Dot5ProV || req.Model == openai.GPT4VisionPreview { - return req.toVisionGenaiContent() - } else if req.Model == TextEmbedding004 || req.Model == string(openai.AdaEmbeddingV2) { + if req.Model == TextEmbedding004 || req.Model == string(openai.AdaEmbeddingV2) { return nil, errors.New("Chat Completion is not supported for embedding model") } - return req.toStringGenaiContent() + return req.toVisionGenaiContent() } func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, error) { content := make([]*genai.Content, 0, len(req.Messages)) for _, message := range req.Messages { - parts, err := message.multiContent() - if err != nil { - return nil, errors.Wrap(err, "message.multiContent") + var parts []openai.ChatMessagePart + + // Attempt to unmarshal into a slice of parts + if err := json.Unmarshal(message.Content, &parts); err != nil { + // If it fails, try unmarshalling into a single string + var singleString string + if err := json.Unmarshal(message.Content, &singleString); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal message content") + } + // Convert single string to a part + parts = []openai.ChatMessagePart{ + {Type: openai.ChatMessagePartTypeText, Text: singleString}, + } } prompt := make([]genai.Part, 0, len(parts)) @@ -103,47 +95,6 @@ func (req *ChatCompletionRequest) toVisionGenaiContent() ([]*genai.Content, erro return content, nil } -func (req *ChatCompletionRequest) toStringGenaiContent() ([]*genai.Content, error) { - content := make([]*genai.Content, 0, len(req.Messages)) - for _, message := range req.Messages { - str, err := message.stringContent() - if err != nil { - return nil, errors.Wrap(err, "message.stringContent") - } - - prompt := []genai.Part{ - genai.Text(str), - } - - switch message.Role { - case openai.ChatMessageRoleSystem: - content = append(content, []*genai.Content{ - { - Parts: prompt, - Role: genaiRoleUser, - }, - { - Parts: []genai.Part{ - genai.Text(""), - }, - Role: genaiRoleModel, - }, - }...) - case openai.ChatMessageRoleAssistant: - content = append(content, &genai.Content{ - Parts: prompt, - Role: genaiRoleModel, - }) - case openai.ChatMessageRoleUser: - content = append(content, &genai.Content{ - Parts: prompt, - Role: genaiRoleUser, - }) - } - } - return content, nil -} - type CompletionChoice struct { Index int `json:"index"` Delta struct {