diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 39b6d3dea..396889ceb 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -88,6 +88,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ }) .optional(), }); +export type GeminiConfig = z.infer; export const gemini10Pro = modelRef({ name: 'googleai/gemini-1.0-pro', @@ -167,13 +168,54 @@ export const SUPPORTED_V15_MODELS = { 'gemini-1.5-flash-8b': gemini15Flash8b, }; -export const SUPPORTED_GEMINI_MODELS: Record< - string, - ModelReference -> = { +export const GENERIC_GEMINI_MODEL = modelRef({ + name: 'googleai/gemini', + configSchema: GeminiConfigSchema, + info: { + label: 'Google Gemini', + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + }, + }, +}); + +export const SUPPORTED_GEMINI_MODELS = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, -}; +} as const; + +function longestMatchingPrefix(version: string, potentialMatches: string[]) { + return potentialMatches + .filter((p) => version.startsWith(p)) + .reduce( + (longest, current) => + current.length > longest.length ? current : longest, + '' + ); +} +export type GeminiVersionString = + | keyof typeof SUPPORTED_GEMINI_MODELS + | (string & {}); + +export function gemini( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const matchingKey = longestMatchingPrefix( + version, + Object.keys(SUPPORTED_GEMINI_MODELS) + ); + if (matchingKey) { + return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({ + ...options, + version, + }); + } + return GENERIC_GEMINI_MODEL.withConfig({ ...options, version }); +} function toGeminiRole( role: MessageData['role'], @@ -473,7 +515,7 @@ export function defineGoogleAIModel( apiVersion?: string, baseUrl?: string, info?: ModelInfo, - defaultConfig?: z.infer + defaultConfig?: GeminiConfig ): ModelAction { if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; @@ -624,7 +666,7 @@ export function defineGoogleAIModel( const chatRequest = { systemInstruction, generationConfig, - tools, + tools: tools.length ? tools : undefined, toolConfig, history: messages .slice(0, -1) diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index abb731426..6f1b7e8e6 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -22,14 +22,22 @@ import { textEmbeddingGecko001, } from './embedder.js'; import { + GENERIC_GEMINI_MODEL, SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, defineGoogleAIModel, + gemini, gemini10Pro, gemini15Flash, gemini15Pro, } from './gemini.js'; -export { gemini10Pro, gemini15Flash, gemini15Pro, textEmbeddingGecko001 }; +export { + gemini, + gemini10Pro, + gemini15Flash, + gemini15Pro, + textEmbeddingGecko001, +}; export interface PluginOptions { apiKey?: string; @@ -48,6 +56,16 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { apiVersions = [options?.apiVersion]; } } + + defineGoogleAIModel( + ai, + GENERIC_GEMINI_MODEL.name, + options?.apiKey, + undefined, + options?.baseUrl, + GENERIC_GEMINI_MODEL.info + ); + if (apiVersions.includes('v1beta')) { Object.keys(SUPPORTED_V15_MODELS).forEach((name) => defineGoogleAIModel( diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 0c608e157..186491abd 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1009,6 +1009,12 @@ importers: '@opentelemetry/sdk-trace-base': specifier: ^1.25.0 version: 1.25.1(@opentelemetry/api@1.9.0) + body-parser: + specifier: ^1.20.3 + version: 1.20.3 + express: + specifier: ^4.21.0 + version: 4.21.0 firebase-admin: specifier: '>=12.2' version: 12.3.1(encoding@0.1.13) @@ -1022,6 +1028,9 @@ importers: rimraf: specifier: ^6.0.1 version: 6.0.1 + tsx: + specifier: ^4.19.2 + version: 4.19.2 typescript: specifier: ^5.3.3 version: 5.4.5 diff --git a/js/testapps/flow-simple-ai/package.json b/js/testapps/flow-simple-ai/package.json index fd0f02e9e..6ca06aad8 100644 --- a/js/testapps/flow-simple-ai/package.json +++ b/js/testapps/flow-simple-ai/package.json @@ -4,7 +4,7 @@ "description": "", "main": "lib/index.js", "scripts": { - "start": "node lib/index.js", + "start": "pnpm exec genkit start -- pnpm exec tsx --watch src/index.ts", "compile": "tsc", "build": "pnpm build:clean && pnpm compile", "build:clean": "rimraf ./lib", @@ -15,18 +15,21 @@ "author": "", "license": "ISC", "dependencies": { - "genkit": "workspace:*", "@genkit-ai/firebase": "workspace:*", "@genkit-ai/google-cloud": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", "@google/generative-ai": "^0.15.0", "@opentelemetry/sdk-trace-base": "^1.25.0", + "body-parser": "^1.20.3", + "express": "^4.21.0", "firebase-admin": ">=12.2", + "genkit": "workspace:*", "partial-json": "^0.1.7" }, "devDependencies": { "rimraf": "^6.0.1", + "tsx": "^4.19.2", "typescript": "^5.3.3" } }