diff --git a/js/plugins/vertexai/package.json b/js/plugins/vertexai/package.json index 8495bbc0f..71e84dd96 100644 --- a/js/plugins/vertexai/package.json +++ b/js/plugins/vertexai/package.json @@ -53,6 +53,7 @@ }, "devDependencies": { "@types/node": "^20.11.16", + "google-gax": "^4.4.1", "npm-run-all": "^4.1.5", "rimraf": "^6.0.1", "tsup": "^8.3.5", diff --git a/js/plugins/vertexai/src/context_caching/constants.ts b/js/plugins/vertexai/src/context_caching/constants.ts new file mode 100644 index 000000000..315076a85 --- /dev/null +++ b/js/plugins/vertexai/src/context_caching/constants.ts @@ -0,0 +1,29 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export const CONTEXT_CACHE_SUPPORTED_MODELS = [ + 'gemini-1.5-flash-001', + 'gemini-1.5-pro-001', +]; + +export const INVALID_ARGUMENT_MESSAGES = { + modelVersion: `Model version is required for context caching, supported only in ${CONTEXT_CACHE_SUPPORTED_MODELS.join(',')} models.`, + tools: 'Context caching cannot be used simultaneously with tools.', + codeExecution: + 'Context caching cannot be used simultaneously with code execution.', +}; + +export const DEFAULT_TTL = 300; diff --git a/js/plugins/vertexai/src/context_caching/index.ts b/js/plugins/vertexai/src/context_caching/index.ts new file mode 100644 index 000000000..3b0732e6d --- /dev/null +++ b/js/plugins/vertexai/src/context_caching/index.ts @@ -0,0 +1,97 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +import { CachedContent, StartChatParams } from '@google-cloud/vertexai'; +import { + ApiClient, + CachedContents, +} from '@google-cloud/vertexai/build/src/resources'; +import { GenerateRequest, GenkitError, z } from 'genkit'; +import { logger } from 'genkit/logging'; +import type { CacheConfigDetails } from './types.js'; +import { + calculateTTL, + generateCacheKey, + getContentForCache, + lookupContextCache, +} from './utils.js'; + +/** + * Handles context caching and transforms the chatRequest for Vertex AI. + * @param apiKey + * @param request + * @param chatRequest + * @param modelVersion + * @returns + */ +export async function handleContextCache( + apiClient: ApiClient, + request: GenerateRequest, + chatRequest: StartChatParams, + modelVersion: string, + cacheConfigDetails: CacheConfigDetails +): Promise<{ cache: CachedContent; newChatRequest: StartChatParams }> { + const cachedContentsClient = new CachedContents(apiClient); + + const { cachedContent, chatRequest: newChatRequest } = getContentForCache( + request, + chatRequest, + modelVersion, + cacheConfigDetails + ); + cachedContent.model = modelVersion; + const cacheKey = generateCacheKey(cachedContent); + + cachedContent.displayName = cacheKey; + + let cache; + try { + cache = await lookupContextCache(cachedContentsClient, cacheKey); + logger.debug(`Cache hit: ${cache ? 'true' : 'false'}`); + } catch (error) { + logger.debug('No cache found, creating one.'); + } + + if (!cache) { + try { + const createParams: CachedContent = { + ...cachedContent, + // TODO: make this neater - idk why they chose to stringify the ttl... + ttl: JSON.stringify(calculateTTL(cacheConfigDetails)) + 's', + }; + cache = await cachedContentsClient.create(createParams); + logger.debug(`Created new cache entry with key: ${cacheKey}`); + } catch (cacheError) { + logger.error( + `Failed to create cache with key ${cacheKey}: ${cacheError}` + ); + throw new GenkitError({ + status: 'INTERNAL', + message: `Failed to create cache: ${cacheError}`, + }); + } + } + + if (!cache) { + throw new GenkitError({ + status: 'INTERNAL', + message: 'Failed to use context cache feature', + }); + } + // This isn't necessary, but it's nice to have for debugging purposes. + newChatRequest.cachedContent = cache.name; + + return { cache, newChatRequest }; +} diff --git a/js/plugins/vertexai/src/context_caching/types.ts b/js/plugins/vertexai/src/context_caching/types.ts new file mode 100644 index 000000000..50f186424 --- /dev/null +++ b/js/plugins/vertexai/src/context_caching/types.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from 'genkit'; + +export const cacheConfigSchema = z.union([ + z.boolean(), + z.object({ ttlSeconds: z.number().optional() }).passthrough(), +]); + +export type CacheConfig = z.infer; + +export const cacheConfigDetailsSchema = z.object({ + cacheConfig: cacheConfigSchema, + endOfCachedContents: z.number(), +}); + +export type CacheConfigDetails = z.infer; diff --git a/js/plugins/vertexai/src/context_caching/utils.ts b/js/plugins/vertexai/src/context_caching/utils.ts new file mode 100644 index 000000000..a2834c753 --- /dev/null +++ b/js/plugins/vertexai/src/context_caching/utils.ts @@ -0,0 +1,241 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { CachedContent, StartChatParams } from '@google-cloud/vertexai'; +import { CachedContents } from '@google-cloud/vertexai/build/src/resources'; +import crypto from 'crypto'; +import { GenkitError, MessageData, z } from 'genkit'; +import { logger } from 'genkit/logging'; +import { GenerateRequest } from 'genkit/model'; +import { + CONTEXT_CACHE_SUPPORTED_MODELS, + DEFAULT_TTL, + INVALID_ARGUMENT_MESSAGES, +} from './constants'; +import { CacheConfig, CacheConfigDetails, cacheConfigSchema } from './types'; + +/** + * Generates a SHA-256 hash to use as a cache key. + * @param request CachedContent - request object to hash + * @returns string - the generated cache key + */ +export function generateCacheKey(request: CachedContent): string { + return crypto + .createHash('sha256') + .update(JSON.stringify(request)) + .digest('hex'); +} + +/** + * Retrieves the content needed for the cache based on the chat history and config details. + */ +export function getContentForCache( + request: GenerateRequest, + chatRequest: StartChatParams, + modelVersion: string, + cacheConfigDetails: CacheConfigDetails +): { + cachedContent: CachedContent; + chatRequest: StartChatParams; + cacheConfig?: CacheConfig; +} { + if (!chatRequest.history?.length) { + throw new Error('No history provided for context caching'); + } + + validateHistoryLength(request, chatRequest); + + const { endOfCachedContents, cacheConfig } = cacheConfigDetails; + const cachedContent: CachedContent = { + model: modelVersion, + contents: chatRequest.history.slice(0, endOfCachedContents + 1), + }; + chatRequest.history = chatRequest.history.slice(endOfCachedContents + 1); + + return { cachedContent, chatRequest, cacheConfig }; +} + +/** + * Validates that the request and chat request history lengths align. + * @throws GenkitError if lengths are mismatched + */ +function validateHistoryLength( + request: GenerateRequest, + chatRequest: StartChatParams +) { + if (chatRequest.history?.length !== request.messages.length - 1) { + throw new GenkitError({ + status: 'INTERNAL', + message: + 'Genkit request history and Gemini chat request history length do not match', + }); + } +} + +/** + * Looks up context cache using a cache manager and returns the found item, if any. + */ +export async function lookupContextCache( + cacheManager: CachedContents, + cacheKey: string, + maxPages = 100, + pageSize = 100 +) { + let currentPage = 0; + let pageToken: string | undefined; + + while (currentPage < maxPages) { + const { cachedContents, nextPageToken } = await cacheManager.list( + pageSize, + pageToken + ); + const found = cachedContents?.find( + (content) => content.displayName === cacheKey + ); + + if (found) return found; + if (!nextPageToken) break; + + pageToken = nextPageToken; + currentPage++; + } + return null; +} + +/** + * Clears all caches using the cache manager. + */ +export async function clearAllCaches( + cacheManager: CachedContents, + maxPages = 100, + pageSize = 100 +): Promise { + let currentPage = 0; + let pageToken: string | undefined; + let totalDeleted = 0; + + while (currentPage < maxPages) { + try { + const { cachedContents, nextPageToken } = await cacheManager.list( + pageSize, + pageToken + ); + totalDeleted += await deleteCachedContents(cacheManager, cachedContents); + + if (!nextPageToken) break; + pageToken = nextPageToken; + currentPage++; + } catch (error) { + throw new GenkitError({ + status: 'INTERNAL', + message: `Error clearing caches on page ${currentPage + 1}: ${error}`, + }); + } + } + logger.info(`Total caches deleted: ${totalDeleted}`); +} + +/** + * Helper to delete cached contents and return the number of deletions. + */ +async function deleteCachedContents( + cacheManager: CachedContents, + cachedContents: CachedContent[] = [] +): Promise { + for (const content of cachedContents) { + if (content.name) await cacheManager.delete(content.name); + } + return cachedContents.length; +} + +/** + * Extracts the cache configuration from the request if available. + */ +export const extractCacheConfig = ( + request: GenerateRequest +): { + cacheConfig: { ttlSeconds?: number } | boolean; + endOfCachedContents: number; +} | null => { + const endOfCachedContents = findLastIndex( + request.messages, + (message) => !!message.metadata?.cache + ); + + return endOfCachedContents === -1 + ? null + : { + endOfCachedContents, + cacheConfig: cacheConfigSchema.parse( + request.messages[endOfCachedContents].metadata?.cache + ), + }; +}; + +/** + * Validates context caching request for compatibility with model and request configurations. + */ +export function validateContextCacheRequest( + request: any, + modelVersion: string +): boolean { + if (!modelVersion || !CONTEXT_CACHE_SUPPORTED_MODELS.includes(modelVersion)) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: INVALID_ARGUMENT_MESSAGES.modelVersion, + }); + } + if (request.tools?.length) + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: INVALID_ARGUMENT_MESSAGES.tools, + }); + if (request.config?.codeExecution) + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: INVALID_ARGUMENT_MESSAGES.codeExecution, + }); + + return true; +} + +/** + * Polyfill function for Array.prototype.findLastIndex for ES2015 compatibility. + */ +function findLastIndex( + array: T[], + callback: (element: T, index: number, array: T[]) => boolean +): number { + for (let i = array.length - 1; i >= 0; i--) { + if (callback(array[i], i, array)) return i; + } + return -1; +} + +/** + * Calculates the TTL (Time-To-Live) for the cache based on cacheConfigDetails. + * @param cacheConfig - The caching configuration details. + * @returns The TTL in seconds. + */ +export function calculateTTL(cacheConfig: CacheConfigDetails): number { + if (cacheConfig.cacheConfig === true) { + return DEFAULT_TTL; + } + if (cacheConfig.cacheConfig === false) { + return 0; + } + return cacheConfig.cacheConfig.ttlSeconds || DEFAULT_TTL; +} diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 97d106bba..524b2ccbd 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -15,6 +15,7 @@ */ import { + CachedContent, Content, FunctionCallingMode, FunctionDeclaration, @@ -23,12 +24,14 @@ import { GenerateContentCandidate, GenerateContentResponse, GenerateContentResult, + GenerativeModelPreview, HarmBlockThreshold, HarmCategory, StartChatParams, ToolConfig, VertexAI, } from '@google-cloud/vertexai'; +import { ApiClient } from '@google-cloud/vertexai/build/src/resources/index.js'; import { GENKIT_CLIENT_HEADER, Genkit, JSONSchema, z } from 'genkit'; import { CandidateData, @@ -48,7 +51,13 @@ import { downloadRequestMedia, simulateSystemPrompt, } from 'genkit/model/middleware'; +import { GoogleAuth } from 'google-auth-library'; import { PluginOptions } from './common/types.js'; +import { handleContextCache } from './context_caching/index.js'; +import { + extractCacheConfig, + validateContextCacheRequest, +} from './context_caching/utils.js'; const SafetySettingsSchema = z.object({ category: z.nativeEnum(HarmCategory), @@ -463,14 +472,6 @@ export function defineGeminiModel( }, async (request, streamingCallback) => { const vertex = vertexClientFactory(request); - const client = vertex.preview.getGenerativeModel( - { - model: request.config?.version || model.version || name, - }, - { - apiClient: GENKIT_CLIENT_HEADER, - } - ); // make a copy so that modifying the request will not produce side-effects const messages = [...request.messages]; @@ -510,7 +511,7 @@ export function defineGeminiModel( (request.output?.format === 'json' || !!request.output?.schema) && tools.length === 0; - const chatRequest: StartChatParams = { + let chatRequest: StartChatParams = { systemInstruction, tools, toolConfig, @@ -529,6 +530,39 @@ export function defineGeminiModel( safetySettings: request.config?.safetySettings, }; + let cache: CachedContent | null = null; + + // TODO: fix casting + const modelVersion = (request.config?.version || + model.version || + name) as string; + + const cacheConfigDetails = extractCacheConfig(request); + + if ( + cacheConfigDetails && + validateContextCacheRequest(request, modelVersion) + ) { + const apiClient = new ApiClient( + options.projectId!, + options.location, + 'v1beta1', + new GoogleAuth(options.googleAuth!) + ); + + const handleContextCacheResponse = await handleContextCache( + apiClient, + request, + chatRequest, + modelVersion, + cacheConfigDetails + ); + chatRequest = handleContextCacheResponse.newChatRequest; + cache = handleContextCacheResponse.cache; + } + + let genModel: GenerativeModelPreview | null = null; + if (jsonMode && request.output?.constrained) { chatRequest.generationConfig!.responseSchema = cleanSchema( request.output.schema @@ -559,8 +593,30 @@ export function defineGeminiModel( }); } const msg = toGeminiMessage(messages[messages.length - 1], model); + + if (cache) { + genModel = vertex.preview.getGenerativeModelFromCachedContent( + cache, + { + model: request.config?.version || model.version || name, + }, + { + apiClient: GENKIT_CLIENT_HEADER, + } + ); + } else { + genModel = vertex.preview.getGenerativeModel( + { + model: request.config?.version || model.version || name, + }, + { + apiClient: GENKIT_CLIENT_HEADER, + } + ); + } + if (streamingCallback) { - const result = await client + const result = await genModel .startChat(chatRequest) .sendMessageStream(msg.parts); for await (const item of result.stream) { @@ -585,7 +641,7 @@ export function defineGeminiModel( } else { let result: GenerateContentResult | undefined; try { - result = await client.startChat(chatRequest).sendMessage(msg.parts); + result = await genModel.startChat(chatRequest).sendMessage(msg.parts); } catch (err) { throw new Error(`Vertex response generation failed: ${err}`); } diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 701c1b0ef..699c3921d 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -688,6 +688,9 @@ importers: '@types/node': specifier: ^20.11.16 version: 20.11.30 + google-gax: + specifier: ^4.4.1 + version: 4.4.1(encoding@0.1.13) npm-run-all: specifier: ^4.1.5 version: 4.1.5 @@ -848,6 +851,9 @@ importers: '@genkit-ai/googleai': specifier: workspace:* version: link:../../plugins/googleai + '@genkit-ai/vertexai': + specifier: workspace:* + version: link:../../plugins/vertexai '@google/generative-ai': specifier: ^0.21.0 version: 0.21.0 @@ -2158,15 +2164,6 @@ packages: resolution: {integrity: sha512-HPa/K5NX6ahMoeBv15njAc/sfF4/jmiXLar9UlC2UfHFKZzsCVLc3wbe7+7qua7w9VPh2/L6EBxyAV7/E8Wftg==} engines: {node: '>=12.10.0'} - '@grpc/grpc-js@1.10.4': - resolution: {integrity: sha512-MqBisuxTkYvPFnEiu+dag3xG/NBUDzSbAFAWlzfkGnQkjVZ6by3h4atbBc+Ikqup1z5BfB4BN18gKWR1YyppNw==} - engines: {node: '>=12.10.0'} - - '@grpc/proto-loader@0.7.12': - resolution: {integrity: sha512-DCVwMxqYzpUCiDMl7hQ384FqP4T3DbNpXU8pt681l3UWCip1WUiD5JrkImUwCB9a7f2cq4CUTmi5r/xIMRPY1Q==} - engines: {node: '>=6'} - hasBin: true - '@grpc/proto-loader@0.7.13': resolution: {integrity: sha512-AiXO/bfe9bmxBjxxtYxFAXGZvMaN5s8kO+jBHAJCON8rJoB5YS/D6X7ZNc6XQkuHNmyl4CYaMI1fJ/Gn27RGGw==} engines: {node: '>=6'} @@ -4250,14 +4247,14 @@ packages: resolution: {integrity: sha512-I/AvzBiUXDzLOy4iIZ2W+Zq33W4lcukQv1nl7C8WUA6SQwyQwUwu3waNmWNAvzds//FG8SZ+DnKnW/2k6mQS8A==} engines: {node: '>=14'} - google-gax@4.3.2: - resolution: {integrity: sha512-2mw7qgei2LPdtGrmd1zvxQviOcduTnsvAWYzCxhOWXK4IQKmQztHnDQwD0ApB690fBQJemFKSU7DnceAy3RLzw==} - engines: {node: '>=14'} - google-gax@4.3.7: resolution: {integrity: sha512-3bnD8RASQyaxOYTdWLgwpQco/aytTxFavoI/UN5QN5txDLp8QRrBHNtCUJ5+Ago+551GD92jG8jJduwvmaneUw==} engines: {node: '>=14'} + google-gax@4.4.1: + resolution: {integrity: sha512-Phyp9fMfA00J3sZbJxbbB4jC55b7DBjE3F6poyL3wKMEBVKA79q6BGuHcTiM28yOzVql0NDbRL8MLLh8Iwk9Dg==} + engines: {node: '>=14'} + google-p12-pem@4.0.1: resolution: {integrity: sha512-WPkN4yGtz05WZ5EhtlxNDWPhC4JIic6G8ePitwUWy4l+XPVYec+a0j0Ts47PDtW59y3RwAhUd9/h9ZZ63px6RQ==} engines: {node: '>=12.0.0'} @@ -5481,10 +5478,6 @@ packages: resolution: {integrity: sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==} engines: {node: '>= 6'} - proto3-json-serializer@2.0.1: - resolution: {integrity: sha512-8awBvjO+FwkMd6gNoGFZyqkHZXCFd54CIYTb6De7dPaufGJ2XNW+QUNqbMr8MaAocMdb+KpsD4rxEOaTBDCffA==} - engines: {node: '>=14.0.0'} - proto3-json-serializer@2.0.2: resolution: {integrity: sha512-SAzp/O4Yh02jGdRc+uIrGoe87dkN/XtwxfZ4ZyafJHymd79ozp5VG5nyZ7ygqPM5+cpLDjjGnYFUkngonyDPOQ==} engines: {node: '>=14.0.0'} @@ -6708,7 +6701,7 @@ snapshots: '@google-cloud/aiplatform@3.25.0(encoding@0.1.13)': dependencies: - google-gax: 4.3.7(encoding@0.1.13) + google-gax: 4.4.1(encoding@0.1.13) protobuf.js: 1.1.2 transitivePeerDependencies: - encoding @@ -6750,7 +6743,7 @@ snapshots: dependencies: fast-deep-equal: 3.1.3 functional-red-black-tree: 1.0.1 - google-gax: 4.3.2(encoding@0.1.13) + google-gax: 4.3.7(encoding@0.1.13) protobufjs: 7.2.6 transitivePeerDependencies: - encoding @@ -6788,8 +6781,13 @@ snapshots: eventid: 2.0.1 extend: 3.0.2 gcp-metadata: 6.1.0(encoding@0.1.13) +<<<<<<< HEAD google-auth-library: 9.14.2(encoding@0.1.13) google-gax: 4.3.2(encoding@0.1.13) +======= + google-auth-library: 9.11.0(encoding@0.1.13) + google-gax: 4.3.7(encoding@0.1.13) +>>>>>>> 87b01855 (feat(js/plugins/vertexai): add context caching) on-finished: 2.4.1 pumpify: 2.0.1 stream-events: 1.0.5 @@ -6889,18 +6887,6 @@ snapshots: '@grpc/proto-loader': 0.7.13 '@js-sdsl/ordered-map': 4.4.2 - '@grpc/grpc-js@1.10.4': - dependencies: - '@grpc/proto-loader': 0.7.12 - '@js-sdsl/ordered-map': 4.4.2 - - '@grpc/proto-loader@0.7.12': - dependencies: - lodash.camelcase: 4.3.0 - long: 5.2.3 - protobufjs: 7.2.6 - yargs: 17.7.2 - '@grpc/proto-loader@0.7.13': dependencies: lodash.camelcase: 4.3.0 @@ -9254,25 +9240,25 @@ snapshots: - encoding - supports-color - google-gax@4.3.2(encoding@0.1.13): + google-gax@4.3.7(encoding@0.1.13): dependencies: - '@grpc/grpc-js': 1.10.4 - '@grpc/proto-loader': 0.7.12 + '@grpc/grpc-js': 1.10.10 + '@grpc/proto-loader': 0.7.13 '@types/long': 4.0.2 abort-controller: 3.0.0 duplexify: 4.1.3 google-auth-library: 9.14.2(encoding@0.1.13) node-fetch: 2.7.0(encoding@0.1.13) object-hash: 3.0.0 - proto3-json-serializer: 2.0.1 - protobufjs: 7.2.6 + proto3-json-serializer: 2.0.2 + protobufjs: 7.3.2 retry-request: 7.0.2(encoding@0.1.13) uuid: 9.0.1 transitivePeerDependencies: - encoding - supports-color - google-gax@4.3.7(encoding@0.1.13): + google-gax@4.4.1(encoding@0.1.13): dependencies: '@grpc/grpc-js': 1.10.10 '@grpc/proto-loader': 0.7.13 @@ -10663,10 +10649,6 @@ snapshots: kleur: 3.0.3 sisteransi: 1.0.5 - proto3-json-serializer@2.0.1: - dependencies: - protobufjs: 7.2.6 - proto3-json-serializer@2.0.2: dependencies: protobufjs: 7.3.2 diff --git a/js/testapps/context-caching/package.json b/js/testapps/context-caching/package.json index aa6c0a927..61c564ed2 100644 --- a/js/testapps/context-caching/package.json +++ b/js/testapps/context-caching/package.json @@ -16,6 +16,7 @@ "license": "ISC", "dependencies": { "@genkit-ai/googleai": "workspace:*", + "@genkit-ai/vertexai": "workspace:*", "@google/generative-ai": "^0.21.0", "genkit": "workspace:*" }, diff --git a/js/testapps/context-caching/src/index.ts b/js/testapps/context-caching/src/index.ts index bbce908c9..af5ab6f6f 100644 --- a/js/testapps/context-caching/src/index.ts +++ b/js/testapps/context-caching/src/index.ts @@ -14,20 +14,24 @@ * limitations under the License. */ -import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; // Import specific AI plugins/models +import { + gemini15Flash as gemini15FlashGoogleAI, + googleAI, +} from '@genkit-ai/googleai'; +import { gemini15Flash, vertexAI } from '@genkit-ai/vertexai'; // Import specific AI plugins/models import * as fs from 'fs/promises'; // Import fs module to handle file operations asynchronously import { genkit, z } from 'genkit'; // Import Genkit framework and Zod for schema validation import { logger } from 'genkit/logging'; // Import logging utility from Genkit const ai = genkit({ - plugins: [googleAI()], // Initialize Genkit with the Google AI plugin + plugins: [vertexAI(), googleAI()], // Initialize Genkit with the Google AI plugin }); logger.setLogLevel('debug'); // Set the logging level to debug for detailed output -export const lotrFlow = ai.defineFlow( +export const lotrFlowVertex = ai.defineFlow( { - name: 'lotrFlow', // Define a unique name for this flow + name: 'lotrFlowVertex', // Define a unique name for this flow inputSchema: z.object({ query: z.string().optional(), // Define a query input, which is optional textFilePath: z.string(), // Add the file path to input schema @@ -35,8 +39,7 @@ export const lotrFlow = ai.defineFlow( outputSchema: z.string(), // Define the expected output as a string }, async ({ query, textFilePath }) => { - const defaultQuery = - "Describe Gandalf's relationship with Frodo, referencing Gandalf quotes from the text."; // Default query to use if none is provided + const defaultQuery = 'What is the text i provided you with?'; // Default query to use if none is provided // Read the content from the file if the path is provided const textContent = await fs.readFile(textFilePath, 'utf-8'); // Read the file as UTF-8 encoded text @@ -71,3 +74,49 @@ export const lotrFlow = ai.defineFlow( return llmResponse.text; // Return the generated text from the model } ); + +export const lotrFlowGoogleAI = ai.defineFlow( + { + name: 'lotrFlowGoogleAI', // Define a unique name for this flow + inputSchema: z.object({ + query: z.string().optional(), // Define a query input, which is optional + textFilePath: z.string(), // Add the file path to input schema + }), + outputSchema: z.string(), // Define the expected output as a string + }, + async ({ query, textFilePath }) => { + const defaultQuery = 'What is the text i provided you with?'; // Default query to use if none is provided + + // Read the content from the file if the path is provided + const textContent = await fs.readFile(textFilePath, 'utf-8'); // Read the file as UTF-8 encoded text + + const llmResponse = await ai.generate({ + messages: [ + { + role: 'user', // Represents the user's input or query + content: [{ text: textContent }], // Use the loaded file content here + }, + { + role: 'model', // Represents the model's response + content: [ + { + text: 'This is the first few chapters of Lord of the Rings. Can I help in any way?', // Example model response + }, + ], + metadata: { + cache: { + ttlSeconds: 300, // Set the cache time-to-live for this message to 300 seconds + }, // this message is the last one to be cached. + }, + }, + ], + config: { + version: 'gemini-1.5-flash-001', // Specify the version of the model to be used + }, + model: gemini15FlashGoogleAI, // Specify the model (gemini15Flash) to use for generation + prompt: query || defaultQuery, // Use the provided query or fall back to the default query + }); + + return llmResponse.text; // Return the generated text from the model + } +);