diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts index 071f4742..c08b1bd1 100644 --- a/apps/workers/inference.ts +++ b/apps/workers/inference.ts @@ -1,5 +1,6 @@ import { Ollama } from "ollama"; import OpenAI from "openai"; +import { ChatCompletion } from "openai/resources/chat/completions"; import serverConfig from "@hoarder/shared/config"; import logger from "@hoarder/shared/logger"; @@ -41,6 +42,24 @@ class OpenAIInferenceClient implements InferenceClient { }); } + /** + * @param chatCompletion the chatCompletion object or the chatCompletion object as a string (e.g. when using OpenRouter in some occasions) + * @returns the messageContent extracted out of the chatCompletion object or out of the chatCompletion string + */ + extractResponse(chatCompletion: ChatCompletion | string): string { + let response: ChatCompletion; + if (typeof chatCompletion === "string") { + response = JSON.parse(chatCompletion) as ChatCompletion; + } else { + response = chatCompletion; + } + const responseContent = response.choices[0].message.content; + if (!responseContent) { + throw new Error(`Got no message content from OpenAI`); + } + return responseContent; + } + async inferFromText(prompt: string): Promise { const chatCompletion = await this.openAI.chat.completions.create({ messages: [{ role: "system", content: prompt }], @@ -48,10 +67,7 @@ class OpenAIInferenceClient implements InferenceClient { response_format: { type: "json_object" }, }); - const response = chatCompletion.choices[0].message.content; - if (!response) { - throw new Error(`Got no message content from OpenAI`); - } + const response = this.extractResponse(chatCompletion); return { response, totalTokens: chatCompletion.usage?.total_tokens }; } @@ -81,10 +97,7 @@ class OpenAIInferenceClient implements InferenceClient { max_tokens: 2000, }); - const response = chatCompletion.choices[0].message.content; - if (!response) { - throw new Error(`Got no message content from OpenAI`); - } + const response = this.extractResponse(chatCompletion); return { response, totalTokens: chatCompletion.usage?.total_tokens }; } }