From 9505b86ad45935827b62ade8f6029d3420e0d4ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?To=C3=A0n=20=C4=90o=C3=A0n?=
Date: Thu, 18 Jul 2024 16:28:00 +0700
Subject: [PATCH] Add Google Gemini API Support (#1330)

* Supports Google Gemini API

* Fixed error when running svelte-check

* Update: incorporates a contribution from nsarrazin
---
 .../configuration/models/providers/google.md |  36 ++++
 package-lock.json                            |  10 +
 package.json                                 |   1 +
 src/lib/server/endpoints/endpoints.ts        |   3 +
 .../server/endpoints/google/endpointGenAI.ts | 171 ++++++++++++++++++
 src/lib/server/models.ts                     |   2 +
 6 files changed, 223 insertions(+)
 create mode 100644 src/lib/server/endpoints/google/endpointGenAI.ts

diff --git a/docs/source/configuration/models/providers/google.md b/docs/source/configuration/models/providers/google.md
index 9d0381e5d94..008baf0cdb4 100644
--- a/docs/source/configuration/models/providers/google.md
+++ b/docs/source/configuration/models/providers/google.md
@@ -44,6 +44,42 @@ MODELS=`[
         }
       }]
     }]
+  }
+]`
+```
+
+## GenAI
+
+Alternatively, you can use the Gemini API via the official [`@google/generative-ai` SDK](https://github.com/google-gemini/generative-ai-js#readme):
+
+> Make sure you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://cloud.google.com/docs/authentication/api-keys).
+
+```ini
+MODELS=`[
+  {
+    "name": "gemini-1.5-flash",
+    "displayName": "Gemini Flash 1.5",
+    "multimodal": true,
+    "endpoints": [
+      {
+        "type": "genai",
+        "apiKey": "abc...xyz",
+        // Optional
+        "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE"
+      }
+    ]
+  },
+  {
+    "name": "gemini-1.5-pro",
+    "displayName": "Gemini Pro 1.5",
+    "multimodal": false,
+    "endpoints": [
+      {
+        "type": "genai",
+        "apiKey": "abc...xyz"
+      }
+    ]
+  }
 ]`
 ```
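As a quick illustration of the configuration above, here is a minimal sketch of how a `genai` endpoint entry is validated by the `endpointGenAIParametersSchema` that this patch introduces. The key and model values are placeholders, and the direct `parse` call is for illustration only; chat-ui performs this validation internally when loading `MODELS`:

```ts
import { endpointGenAIParametersSchema } from "$lib/server/endpoints/google/endpointGenAI";

// Placeholder values; note that "safetyThreshold" lives inside the endpoint
// object, matching the schema introduced below.
const parsed = endpointGenAIParametersSchema.parse({
	type: "genai",
	apiKey: "abc...xyz",
	safetyThreshold: "BLOCK_MEDIUM_AND_ABOVE",
	model: { id: "gemini-1.5-flash" },
});

console.log(parsed.weight); // 1 — defaulted by the schema
```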
diff --git a/package-lock.json b/package-lock.json
index fa821ca2460..2eff8ac64e3 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -96,6 +96,7 @@
         "@anthropic-ai/sdk": "^0.17.1",
         "@anthropic-ai/vertex-sdk": "^0.3.0",
         "@google-cloud/vertexai": "^1.1.0",
+        "@google/generative-ai": "^0.14.1",
         "aws4fetch": "^1.0.17",
         "cohere-ai": "^7.9.0",
         "openai": "^4.44.0"
@@ -1360,6 +1361,15 @@
         "node": ">=18.0.0"
       }
     },
+    "node_modules/@google/generative-ai": {
+      "version": "0.14.1",
+      "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.14.1.tgz",
+      "integrity": "sha512-pevEyZCb0Oc+dYNlSberW8oZBm4ofeTD5wN01TowQMhTwdAbGAnJMtQzoklh6Blq2AKsx8Ox6FWa44KioZLZiA==",
+      "optional": true,
+      "engines": {
+        "node": ">=18.0.0"
+      }
+    },
     "node_modules/@gradio/client": {
       "version": "0.19.4",
       "resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.19.4.tgz",
diff --git a/package.json b/package.json
index e72a3612ba3..74d8f383e2a 100644
--- a/package.json
+++ b/package.json
@@ -106,6 +106,7 @@
     "@anthropic-ai/sdk": "^0.17.1",
     "@anthropic-ai/vertex-sdk": "^0.3.0",
     "@google-cloud/vertexai": "^1.1.0",
+    "@google/generative-ai": "^0.14.1",
     "aws4fetch": "^1.0.17",
     "cohere-ai": "^7.9.0",
     "openai": "^4.44.0"
diff --git a/src/lib/server/endpoints/endpoints.ts b/src/lib/server/endpoints/endpoints.ts
index e2970d57df1..d138df6ed17 100644
--- a/src/lib/server/endpoints/endpoints.ts
+++ b/src/lib/server/endpoints/endpoints.ts
@@ -8,6 +8,7 @@ import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
 import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
 import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
 import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
+import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
 import {
 	endpointAnthropic,
@@ -65,6 +66,7 @@ export const endpoints = {
 	llamacpp: endpointLlamacpp,
 	ollama: endpointOllama,
 	vertex: endpointVertex,
+	genai: endpointGenAI,
 	cloudflare: endpointCloudflare,
 	cohere: endpointCohere,
 	langserve: endpointLangserve,
@@ -79,6 +81,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
 	endpointLlamacppParametersSchema,
 	endpointOllamaParametersSchema,
 	endpointVertexParametersSchema,
+	endpointGenAIParametersSchema,
 	endpointCloudflareParametersSchema,
 	endpointCohereParametersSchema,
 	endpointLangserveParametersSchema,
diff --git a/src/lib/server/endpoints/google/endpointGenAI.ts b/src/lib/server/endpoints/google/endpointGenAI.ts
new file mode 100644
index 00000000000..3a6de6fa675
--- /dev/null
+++ b/src/lib/server/endpoints/google/endpointGenAI.ts
@@ -0,0 +1,171 @@
+import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
+import type { Content, Part, TextPart } from "@google/generative-ai";
+import { z } from "zod";
+import type { Message, MessageFile } from "$lib/types/Message";
+import type { TextGenerationStreamOutput } from "@huggingface/inference";
+import type { Endpoint } from "../endpoints";
+import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
+import type { ImageProcessorOptions } from "../images";
+
+export const endpointGenAIParametersSchema = z.object({
+	weight: z.number().int().positive().default(1),
+	model: z.any(),
+	type: z.literal("genai"),
+	apiKey: z.string(),
+	safetyThreshold: z
+		.enum([
+			HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
+			HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+			HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+			HarmBlockThreshold.BLOCK_NONE,
+			HarmBlockThreshold.BLOCK_ONLY_HIGH,
+		])
+		.optional(),
+	multimodal: z
+		.object({
+			image: createImageProcessorOptionsValidator({
+				supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"],
+				preferredMimeType: "image/webp",
+				// 3.75MB = 3/4 of a 5MB budget; the 3/4 factor compensates for
+				// the ~33% size increase when converting to base64
+				maxSizeInMB: (5 / 4) * 3,
+				maxWidth: 4096,
+				maxHeight: 4096,
+			}),
+		})
+		.default({}),
+});
+
+export function endpointGenAI(input: z.input<typeof endpointGenAIParametersSchema>): Endpoint {
+	const { model, apiKey, safetyThreshold, multimodal } = endpointGenAIParametersSchema.parse(input);
+
+	const genAI = new GoogleGenerativeAI(apiKey);
+
+	return async ({ messages, preprompt, generateSettings }) => {
+		const parameters = { ...model.parameters, ...generateSettings };
+
+		const generativeModel = genAI.getGenerativeModel({
+			model: model.id ?? model.name,
+			// Apply the configured threshold uniformly across all harm categories
+			safetySettings: safetyThreshold
+				? [
+						{
+							category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+							threshold: safetyThreshold,
+						},
+						{
+							category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+							threshold: safetyThreshold,
+						},
+						{
+							category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+							threshold: safetyThreshold,
+						},
+						{
+							category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+							threshold: safetyThreshold,
+						},
+						{
+							category: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+							threshold: safetyThreshold,
+						},
+					]
+				: undefined,
+			generationConfig: {
+				maxOutputTokens: parameters?.max_new_tokens ?? 4096,
+				stopSequences: parameters?.stop,
+				temperature: parameters?.temperature ?? 1,
+			},
+		});
+
+		// A leading system message takes precedence over the model preprompt
+		let systemMessage = preprompt;
+		if (messages[0].from === "system") {
+			systemMessage = messages[0].content;
+			messages.shift();
+		}
+
+		const genAIMessages = await Promise.all(
+			messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
+				return {
+					role: from === "user" ? "user" : "model",
+					parts: [
+						...(await Promise.all(
+							(files ?? []).map((file) => fileToImageBlock(file, multimodal.image))
+						)),
+						{ text: content },
+					],
+				};
+			})
+		);
+
+		const result = await generativeModel.generateContentStream({
+			contents: genAIMessages,
+			systemInstruction:
+				systemMessage && systemMessage.trim() !== ""
+					? {
+							role: "system",
+							parts: [{ text: systemMessage }],
+						}
+					: undefined,
+		});
+
+		let tokenId = 0;
+		return (async function* () {
+			let generatedText = "";
+
+			for await (const data of result.stream) {
+				if (!data?.candidates?.length) break; // Handle case where no candidates are present
+
+				const candidate = data.candidates[0];
+				if (!candidate.content?.parts?.length) continue; // Skip if no parts are present
+
+				const firstPart = candidate.content.parts.find((part) => "text" in part) as
+					| TextPart
+					| undefined;
+				if (!firstPart) continue; // Skip if no text part is found
+
+				const content = firstPart.text;
+				generatedText += content;
+
+				const output: TextGenerationStreamOutput = {
+					token: {
+						id: tokenId++,
+						text: content,
+						logprob: 0,
+						special: false,
+					},
+					generated_text: null,
+					details: null,
+				};
+				yield output;
+			}
+
+			// Emit a final "special" token carrying the full generated text
+			const output: TextGenerationStreamOutput = {
+				token: {
+					id: tokenId++,
+					text: "",
+					logprob: 0,
+					special: true,
+				},
+				generated_text: generatedText,
+				details: null,
+			};
+			yield output;
+		})();
+	};
+}
+
+async function fileToImageBlock(
+	file: MessageFile,
+	opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
+): Promise<Part> {
+	const processor = makeImageProcessor(opts);
+	const { image, mime } = await processor(file);
+
+	return {
+		inlineData: {
+			mimeType: mime,
+			data: image.toString("base64"),
+		},
+	};
+}
+
+export default endpointGenAI;
diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts
index f463bca5c97..4e67b46b430 100644
--- a/src/lib/server/models.ts
+++ b/src/lib/server/models.ts
@@ -230,6 +230,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
 			return endpoints.ollama(args);
 		case "vertex":
 			return await endpoints.vertex(args);
+		case "genai":
+			return await endpoints.genai(args);
 		case "cloudflare":
 			return await endpoints.cloudflare(args);
 		case "cohere":
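Finally, a minimal usage sketch of the new endpoint outside chat-ui's normal model-loading path. The message and model shapes are simplified assumptions (chat-ui builds these from the `MODELS` config and the conversation store, and its `Message` type carries additional fields such as `id`), and `GOOGLE_GENAI_API_KEY` is an assumed variable name, not something this patch defines:

```ts
import endpointGenAI from "$lib/server/endpoints/google/endpointGenAI";

// Build the endpoint from a config object like the ones shown in google.md.
const endpoint = endpointGenAI({
	type: "genai",
	weight: 1,
	apiKey: process.env.GOOGLE_GENAI_API_KEY ?? "", // assumed env var name
	model: { id: "gemini-1.5-flash", name: "gemini-1.5-flash", parameters: {} },
});

// Stream a completion; the generator yields one token per chunk and ends with
// a "special" token carrying the full generated_text.
const stream = await endpoint({
	messages: [{ from: "user", content: "Hello, Gemini!" }],
	preprompt: "You are a concise assistant.",
	generateSettings: { temperature: 0.7, max_new_tokens: 256 },
});

for await (const output of stream) {
	if (!output.token.special) {
		process.stdout.write(output.token.text);
	}
}
```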