Skip to content

Commit

Permalink
feat(js/plugins/vertexai): add context caching
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Nov 18, 2024
1 parent 4b1499b commit 7f44d4c
Show file tree
Hide file tree
Showing 9 changed files with 545 additions and 58 deletions.
1 change: 1 addition & 0 deletions js/plugins/vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
},
"devDependencies": {
"@types/node": "^20.11.16",
"google-gax": "^4.4.1",
"npm-run-all": "^4.1.5",
"rimraf": "^6.0.1",
"tsup": "^8.3.5",
Expand Down
29 changes: 29 additions & 0 deletions js/plugins/vertexai/src/context_caching/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

export const CONTEXT_CACHE_SUPPORTED_MODELS = [
'gemini-1.5-flash-001',
'gemini-1.5-pro-001',
];

export const INVALID_ARGUMENT_MESSAGES = {
modelVersion: `Model version is required for context caching, supported only in ${CONTEXT_CACHE_SUPPORTED_MODELS.join(',')} models.`,
tools: 'Context caching cannot be used simultaneously with tools.',
codeExecution:
'Context caching cannot be used simultaneously with code execution.',
};

export const DEFAULT_TTL = 300;
97 changes: 97 additions & 0 deletions js/plugins/vertexai/src/context_caching/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

import { CachedContent, StartChatParams } from '@google-cloud/vertexai';
import {
ApiClient,
CachedContents,
} from '@google-cloud/vertexai/build/src/resources';
import { GenerateRequest, GenkitError, z } from 'genkit';
import { logger } from 'genkit/logging';
import type { CacheConfigDetails } from './types.js';
import {
calculateTTL,
generateCacheKey,
getContentForCache,
lookupContextCache,
} from './utils.js';

/**
* Handles context caching and transforms the chatRequest for Vertex AI.
* @param apiKey
* @param request
* @param chatRequest
* @param modelVersion
* @returns
*/
export async function handleContextCache(
apiClient: ApiClient,
request: GenerateRequest<z.ZodTypeAny>,
chatRequest: StartChatParams,
modelVersion: string,
cacheConfigDetails: CacheConfigDetails
): Promise<{ cache: CachedContent; newChatRequest: StartChatParams }> {
const cachedContentsClient = new CachedContents(apiClient);

const { cachedContent, chatRequest: newChatRequest } = getContentForCache(
request,
chatRequest,
modelVersion,
cacheConfigDetails
);
cachedContent.model = modelVersion;
const cacheKey = generateCacheKey(cachedContent);

cachedContent.displayName = cacheKey;

let cache;
try {
cache = await lookupContextCache(cachedContentsClient, cacheKey);
logger.debug(`Cache hit: ${cache ? 'true' : 'false'}`);
} catch (error) {
logger.debug('No cache found, creating one.');
}

if (!cache) {
try {
const createParams: CachedContent = {
...cachedContent,
// TODO: make this neater - idk why they chose to stringify the ttl...
ttl: JSON.stringify(calculateTTL(cacheConfigDetails)) + 's',
};
cache = await cachedContentsClient.create(createParams);
logger.debug(`Created new cache entry with key: ${cacheKey}`);
} catch (cacheError) {
logger.error(
`Failed to create cache with key ${cacheKey}: ${cacheError}`
);
throw new GenkitError({
status: 'INTERNAL',
message: `Failed to create cache: ${cacheError}`,
});
}
}

if (!cache) {
throw new GenkitError({
status: 'INTERNAL',
message: 'Failed to use context cache feature',
});
}
// This isn't necessary, but it's nice to have for debugging purposes.
newChatRequest.cachedContent = cache.name;

return { cache, newChatRequest };
}
31 changes: 31 additions & 0 deletions js/plugins/vertexai/src/context_caching/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { z } from 'genkit';

export const cacheConfigSchema = z.union([
z.boolean(),
z.object({ ttlSeconds: z.number().optional() }).passthrough(),
]);

export type CacheConfig = z.infer<typeof cacheConfigSchema>;

export const cacheConfigDetailsSchema = z.object({
cacheConfig: cacheConfigSchema,
endOfCachedContents: z.number(),
});

export type CacheConfigDetails = z.infer<typeof cacheConfigDetailsSchema>;
Loading

0 comments on commit 7f44d4c

Please sign in to comment.