Add Google Gemini API Support (huggingface#1330)
* Supports Google Gemini API

* Fixed error when running svelte-check

* update: from the contribution of nsarrazin
toandev95 authored Jul 18, 2024
1 parent 992fe43 commit 9505b86
Showing 6 changed files with 223 additions and 0 deletions.
36 changes: 36 additions & 0 deletions docs/source/configuration/models/providers/google.md
@@ -44,6 +44,42 @@ MODELS=`[
}
}]
}]
}
]`
```

## GenAI

Or use the Gemini API provider via the [`@google/generative-ai` SDK](https://github.com/google-gemini/generative-ai-js#readme):

> Make sure that you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://cloud.google.com/docs/authentication/api-keys).
```ini
MODELS=`[
{
"name": "gemini-1.5-flash",
"displayName": "Gemini Flash 1.5",
"multimodal": true,
"endpoints": [
{
"type": "genai",
"apiKey": "abc...xyz"
}
]

// Optional
"safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
},
{
"name": "gemini-1.5-pro",
"displayName": "Gemini Pro 1.5",
"multimodal": false,
"endpoints": [
{
"type": "genai",
"apiKey": "abc...xyz"
}
]
}
]`
```
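
To sanity-check an API key and model name outside of chat-ui, a minimal sketch using the same SDK calls the endpoint relies on (`GoogleGenerativeAI`, `getGenerativeModel`, `generateContentStream`) might look like the following; the `GEMINI_API_KEY` environment variable and the prompt are illustrative assumptions:

```ts
import { GoogleGenerativeAI } from "@google/generative-ai";

// Minimal sketch, assuming GEMINI_API_KEY is set in the environment.
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY ?? "");
const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" });

// Stream a single-turn request, mirroring how the endpoint below calls the SDK.
const result = await model.generateContentStream({
	contents: [{ role: "user", parts: [{ text: "Say hello." }] }],
});

for await (const chunk of result.stream) {
	process.stdout.write(chunk.text());
}
```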
10 changes: 10 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -106,6 +106,7 @@
"@anthropic-ai/sdk": "^0.17.1",
"@anthropic-ai/vertex-sdk": "^0.3.0",
"@google-cloud/vertexai": "^1.1.0",
"@google/generative-ai": "^0.14.1",
"aws4fetch": "^1.0.17",
"cohere-ai": "^7.9.0",
"openai": "^4.44.0"
3 changes: 3 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
@@ -8,6 +8,7 @@ import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";

import {
endpointAnthropic,
@@ -65,6 +66,7 @@ export const endpoints = {
llamacpp: endpointLlamacpp,
ollama: endpointOllama,
vertex: endpointVertex,
genai: endpointGenAI,
cloudflare: endpointCloudflare,
cohere: endpointCohere,
langserve: endpointLangserve,
@@ -79,6 +81,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
endpointLlamacppParametersSchema,
endpointOllamaParametersSchema,
endpointVertexParametersSchema,
endpointGenAIParametersSchema,
endpointCloudflareParametersSchema,
endpointCohereParametersSchema,
endpointLangserveParametersSchema,
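With these two registrations, an endpoint config carrying `"type": "genai"` now passes the discriminated union and resolves to `endpointGenAI`. A hypothetical sketch of that flow (the sample config and the `$lib` import path are assumptions):

```ts
import { endpoints, endpointSchema } from "$lib/server/endpoints/endpoints";

// Hypothetical config; mirrors the shape documented in google.md above.
const parsed = endpointSchema.parse({
	type: "genai",
	apiKey: "abc...xyz",
	model: { name: "gemini-1.5-flash" },
});

// The "type" discriminant narrows the parse result, and the endpoints map
// resolves the matching implementation.
if (parsed.type === "genai") {
	const endpoint = endpoints.genai(parsed);
}
```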
171 changes: 171 additions & 0 deletions src/lib/server/endpoints/google/endpointGenAI.ts
@@ -0,0 +1,171 @@
import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
import type { Content, Part, TextPart } from "@google/generative-ai";
import { z } from "zod";
import type { Message, MessageFile } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { ImageProcessorOptions } from "../images";

export const endpointGenAIParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("genai"),
apiKey: z.string(),
safetyThreshold: z
.enum([
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmBlockThreshold.BLOCK_NONE,
HarmBlockThreshold.BLOCK_ONLY_HIGH,
])
.optional(),
multimodal: z
.object({
image: createImageProcessorOptionsValidator({
supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"],
preferredMimeType: "image/webp",
// The 4 / 3 compensates for the 33% increase in size when converting to base64
maxSizeInMB: (5 / 4) * 3,
maxWidth: 4096,
maxHeight: 4096,
}),
})
.default({}),
});

export function endpointGenAI(input: z.input<typeof endpointGenAIParametersSchema>): Endpoint {
const { model, apiKey, safetyThreshold, multimodal } = endpointGenAIParametersSchema.parse(input);

const genAI = new GoogleGenerativeAI(apiKey);

return async ({ messages, preprompt, generateSettings }) => {
const parameters = { ...model.parameters, ...generateSettings };

const generativeModel = genAI.getGenerativeModel({
model: model.id ?? model.name,
safetySettings: safetyThreshold
? [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: safetyThreshold,
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: safetyThreshold,
},
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: safetyThreshold,
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: safetyThreshold,
},
{
category: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
threshold: safetyThreshold,
},
]
: undefined,
generationConfig: {
maxOutputTokens: parameters?.max_new_tokens ?? 4096,
stopSequences: parameters?.stop,
temperature: parameters?.temperature ?? 1,
},
});

let systemMessage = preprompt;
if (messages[0].from === "system") {
systemMessage = messages[0].content;
messages.shift();
}

const genAIMessages = await Promise.all(
messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
return {
role: from === "user" ? "user" : "model",
parts: [
...(await Promise.all(
(files ?? []).map((file) => fileToImageBlock(file, multimodal.image))
)),
{ text: content },
],
};
})
);

const result = await generativeModel.generateContentStream({
contents: genAIMessages,
systemInstruction:
systemMessage && systemMessage.trim() !== ""
? {
role: "system",
parts: [{ text: systemMessage }],
}
: undefined,
});

let tokenId = 0;
return (async function* () {
let generatedText = "";

for await (const data of result.stream) {
if (!data?.candidates?.length) break; // Handle case where no candidates are present

const candidate = data.candidates[0];
if (!candidate.content?.parts?.length) continue; // Skip if no parts are present

const firstPart = candidate.content.parts.find((part) => "text" in part) as
| TextPart
| undefined;
if (!firstPart) continue; // Skip if no text part is found

const content = firstPart.text;
generatedText += content;

const output: TextGenerationStreamOutput = {
token: {
id: tokenId++,
text: content,
logprob: 0,
special: false,
},
generated_text: null,
details: null,
};
yield output;
}

const output: TextGenerationStreamOutput = {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: generatedText,
details: null,
};
yield output;
})();
};
}

async function fileToImageBlock(
file: MessageFile,
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
): Promise<Part> {
const processor = makeImageProcessor(opts);
const { image, mime } = await processor(file);

return {
inlineData: {
mimeType: mime,
data: image.toString("base64"),
},
};
}

export default endpointGenAI;
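
For orientation, invoking the new endpoint might look roughly like the sketch below; the message and settings objects are trimmed, hypothetical stand-ins for what chat-ui actually passes (the real wiring is the `models.ts` hunk that follows):

```ts
// Hypothetical usage sketch; in chat-ui the call comes from models.ts (next hunk).
const endpoint = endpointGenAI({
	type: "genai",
	apiKey: process.env.GEMINI_API_KEY ?? "",
	model: { name: "gemini-1.5-flash", parameters: {} },
});

const stream = await endpoint({
	messages: [{ from: "user", content: "Hello!", files: [] }] as Message[],
	preprompt: "You are a concise assistant.",
	generateSettings: { temperature: 0.5, max_new_tokens: 512 },
});

// Print streamed tokens; the final special token carries the full generated_text.
for await (const output of stream) {
	if (!output.token.special) process.stdout.write(output.token.text);
}
```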
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
@@ -230,6 +230,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
return endpoints.ollama(args);
case "vertex":
return await endpoints.vertex(args);
case "genai":
return await endpoints.genai(args);
case "cloudflare":
return await endpoints.cloudflare(args);
case "cohere":
