forked from huggingface/chat-ui
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Google Gemini API Support (huggingface#1330)
* Supports the Google Gemini API
* Fixed an error when running svelte-check
* Update: incorporates the contribution of nsarrazin
- Loading branch information
Showing
6 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai"; | ||
import type { Content, Part, TextPart } from "@google/generative-ai"; | ||
import { z } from "zod"; | ||
import type { Message, MessageFile } from "$lib/types/Message"; | ||
import type { TextGenerationStreamOutput } from "@huggingface/inference"; | ||
import type { Endpoint } from "../endpoints"; | ||
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; | ||
import type { ImageProcessorOptions } from "../images"; | ||
|
||
export const endpointGenAIParametersSchema = z.object({ | ||
weight: z.number().int().positive().default(1), | ||
model: z.any(), | ||
type: z.literal("genai"), | ||
apiKey: z.string(), | ||
safetyThreshold: z | ||
.enum([ | ||
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, | ||
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, | ||
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, | ||
HarmBlockThreshold.BLOCK_NONE, | ||
HarmBlockThreshold.BLOCK_ONLY_HIGH, | ||
]) | ||
.optional(), | ||
multimodal: z | ||
.object({ | ||
image: createImageProcessorOptionsValidator({ | ||
supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"], | ||
preferredMimeType: "image/webp", | ||
// The 4 / 3 compensates for the 33% increase in size when converting to base64 | ||
maxSizeInMB: (5 / 4) * 3, | ||
maxWidth: 4096, | ||
maxHeight: 4096, | ||
}), | ||
}) | ||
.default({}), | ||
}); | ||
|
||
export function endpointGenAI(input: z.input<typeof endpointGenAIParametersSchema>): Endpoint { | ||
const { model, apiKey, safetyThreshold, multimodal } = endpointGenAIParametersSchema.parse(input); | ||
|
||
const genAI = new GoogleGenerativeAI(apiKey); | ||
|
||
return async ({ messages, preprompt, generateSettings }) => { | ||
const parameters = { ...model.parameters, ...generateSettings }; | ||
|
||
const generativeModel = genAI.getGenerativeModel({ | ||
model: model.id ?? model.name, | ||
safetySettings: safetyThreshold | ||
? [ | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, | ||
threshold: safetyThreshold, | ||
}, | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HARASSMENT, | ||
threshold: safetyThreshold, | ||
}, | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
threshold: safetyThreshold, | ||
}, | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, | ||
threshold: safetyThreshold, | ||
}, | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_UNSPECIFIED, | ||
threshold: safetyThreshold, | ||
}, | ||
] | ||
: undefined, | ||
generationConfig: { | ||
maxOutputTokens: parameters?.max_new_tokens ?? 4096, | ||
stopSequences: parameters?.stop, | ||
temperature: parameters?.temperature ?? 1, | ||
}, | ||
}); | ||
|
||
let systemMessage = preprompt; | ||
if (messages[0].from === "system") { | ||
systemMessage = messages[0].content; | ||
messages.shift(); | ||
} | ||
|
||
const genAIMessages = await Promise.all( | ||
messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => { | ||
return { | ||
role: from === "user" ? "user" : "model", | ||
parts: [ | ||
...(await Promise.all( | ||
(files ?? []).map((file) => fileToImageBlock(file, multimodal.image)) | ||
)), | ||
{ text: content }, | ||
], | ||
}; | ||
}) | ||
); | ||
|
||
const result = await generativeModel.generateContentStream({ | ||
contents: genAIMessages, | ||
systemInstruction: | ||
systemMessage && systemMessage.trim() !== "" | ||
? { | ||
role: "system", | ||
parts: [{ text: systemMessage }], | ||
} | ||
: undefined, | ||
}); | ||
|
||
let tokenId = 0; | ||
return (async function* () { | ||
let generatedText = ""; | ||
|
||
for await (const data of result.stream) { | ||
if (!data?.candidates?.length) break; // Handle case where no candidates are present | ||
|
||
const candidate = data.candidates[0]; | ||
if (!candidate.content?.parts?.length) continue; // Skip if no parts are present | ||
|
||
const firstPart = candidate.content.parts.find((part) => "text" in part) as | ||
| TextPart | ||
| undefined; | ||
if (!firstPart) continue; // Skip if no text part is found | ||
|
||
const content = firstPart.text; | ||
generatedText += content; | ||
|
||
const output: TextGenerationStreamOutput = { | ||
token: { | ||
id: tokenId++, | ||
text: content, | ||
logprob: 0, | ||
special: false, | ||
}, | ||
generated_text: null, | ||
details: null, | ||
}; | ||
yield output; | ||
} | ||
|
||
const output: TextGenerationStreamOutput = { | ||
token: { | ||
id: tokenId++, | ||
text: "", | ||
logprob: 0, | ||
special: true, | ||
}, | ||
generated_text: generatedText, | ||
details: null, | ||
}; | ||
yield output; | ||
})(); | ||
}; | ||
} | ||
|
||
async function fileToImageBlock( | ||
file: MessageFile, | ||
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp"> | ||
): Promise<Part> { | ||
const processor = makeImageProcessor(opts); | ||
const { image, mime } = await processor(file); | ||
|
||
return { | ||
inlineData: { | ||
mimeType: mime, | ||
data: image.toString("base64"), | ||
}, | ||
}; | ||
} | ||
|
||
// Also exposed as the default export so the module can be imported either way.
export default endpointGenAI;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters