From 062b506a35034c05b140702c2e7ad13b5b981117 Mon Sep 17 00:00:00 2001 From: Taiki Maekawa Date: Wed, 18 Dec 2024 14:27:57 +0900 Subject: [PATCH] =?UTF-8?q?Video=20Chat=20=E3=81=AE=E4=BF=AE=E6=AD=A3=20(#?= =?UTF-8?q?781)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/cdk/lambda/utils/models.ts | 11 ++++ packages/cdk/lib/construct/api.ts | 2 +- .../cdk/lib/generative-ai-use-cases-stack.ts | 3 + packages/common/src/model.ts | 14 ++--- packages/types/src/message.d.ts | 6 +- packages/web/src/hooks/useChat.ts | 48 +++++++++++----- packages/web/src/hooks/useFileApi.ts | 57 +++++++++++-------- packages/web/src/pages/ChatPage.tsx | 2 +- 8 files changed, 91 insertions(+), 52 deletions(-) diff --git a/packages/cdk/lambda/utils/models.ts b/packages/cdk/lambda/utils/models.ts index 5aa424ea..aa0bc428 100644 --- a/packages/cdk/lambda/utils/models.ts +++ b/packages/cdk/lambda/utils/models.ts @@ -225,6 +225,17 @@ const createConverseCommandInput = ( }, }, } as ContentBlock.VideoMember); + } else if (extra.type === 'video' && extra.source.type === 's3') { + contentBlocks.push({ + video: { + format: extra.source.mediaType.split('/')[1], + source: { + s3Location: { + uri: extra.source.data, + }, + }, + }, + } as ContentBlock.VideoMember); } }); } diff --git a/packages/cdk/lib/construct/api.ts b/packages/cdk/lib/construct/api.ts index e08e7a31..97e4215b 100644 --- a/packages/cdk/lib/construct/api.ts +++ b/packages/cdk/lib/construct/api.ts @@ -163,7 +163,7 @@ export class Api extends Construct { ], }, }); - fileBucket.grantWrite(predictStreamFunction); + fileBucket.grantReadWrite(predictStreamFunction); predictStreamFunction.grantInvoke(idPool.authenticatedRole); // Prompt Flow Lambda Function の追加 diff --git a/packages/cdk/lib/generative-ai-use-cases-stack.ts b/packages/cdk/lib/generative-ai-use-cases-stack.ts index c36e9e13..0a0978a3 100644 --- a/packages/cdk/lib/generative-ai-use-cases-stack.ts +++ b/packages/cdk/lib/generative-ai-use-cases-stack.ts @@ -306,5 +306,8 @@ export class GenerativeAiUseCasesStack extends Stack { this.userPool = auth.userPool; this.userPoolClient = auth.client; + + this.exportValue(this.userPool.userPoolId); + this.exportValue(this.userPoolClient.userPoolClientId); } } diff --git a/packages/common/src/model.ts b/packages/common/src/model.ts index d1e328c6..8fb77999 100644 --- a/packages/common/src/model.ts +++ b/packages/common/src/model.ts @@ -37,6 +37,13 @@ export const modelFeatureFlags: Record = { // Amazon Titan 'amazon.titan-text-express-v1': MODEL_FEATURE.TEXT_DOC, 'amazon.titan-text-premier-v1:0': MODEL_FEATURE.TEXT_ONLY, + // Amazon Nova + 'amazon.nova-pro-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, + 'amazon.nova-lite-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, + 'amazon.nova-micro-v1:0': MODEL_FEATURE.TEXT_ONLY, + 'us.amazon.nova-pro-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE, // S3 Video アップロードが us-east-1 のみ対応のため。 Video を利用したい場合は us-east-1 の amazon.nova-pro-v1:0 で利用できます。(注意: リージョン変更の際 RAG を有効化している場合削除されます) + 'us.amazon.nova-lite-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE, // 同上 + 'us.amazon.nova-micro-v1:0': MODEL_FEATURE.TEXT_ONLY, // Meta 'meta.llama3-8b-instruct-v1:0': MODEL_FEATURE.TEXT_DOC, 'meta.llama3-70b-instruct-v1:0': MODEL_FEATURE.TEXT_DOC, @@ -56,13 +63,6 @@ export const modelFeatureFlags: Record = { // Cohere 'cohere.command-r-v1:0': MODEL_FEATURE.TEXT_DOC, 'cohere.command-r-plus-v1:0': MODEL_FEATURE.TEXT_DOC, - // Amazon Nova - 'amazon.nova-pro-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, - 'amazon.nova-lite-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, - 'amazon.nova-micro-v1:0': MODEL_FEATURE.TEXT_ONLY, - 'us.amazon.nova-pro-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, - 'us.amazon.nova-lite-v1:0': MODEL_FEATURE.TEXT_DOC_IMAGE_VIDEO, - 'us.amazon.nova-micro-v1:0': MODEL_FEATURE.TEXT_ONLY, // Stability AI Image Gen 'stability.stable-diffusion-xl-v1': MODEL_FEATURE.IMAGE_GEN, 'stability.sd3-large-v1:0': MODEL_FEATURE.IMAGE_GEN, diff --git a/packages/types/src/message.d.ts b/packages/types/src/message.d.ts index 571de00e..c51acf5c 100644 --- a/packages/types/src/message.d.ts +++ b/packages/types/src/message.d.ts @@ -44,9 +44,9 @@ export type ExtraData = { type: string; // 'image' | 'file' name: string; source: { - type: string; // 'S3' - mediaType: string; // file type - data: string; + type: string; // 'S3' | 'base64' + mediaType: string; // mime type (i.e. image/png) + data: string; // s3 location for s3, data for base64 }; }; diff --git a/packages/web/src/hooks/useChat.ts b/packages/web/src/hooks/useChat.ts index 01edb4a6..afe1d6c1 100644 --- a/packages/web/src/hooks/useChat.ts +++ b/packages/web/src/hooks/useChat.ts @@ -24,6 +24,7 @@ import useChatList from './useChatList'; // mutateListChat の本来の型は InfiniteKeyedMutator import { getPrompter } from '../prompts'; import { findModelByModelId } from './useModel'; +import useFileApi from './useFileApi'; const useChatState = create<{ chats: { @@ -83,6 +84,7 @@ const useChatState = create<{ predictStream, predictTitle, } = useChatApi(); + const { getS3Uri } = useFileApi(); const getModelId = (id: string) => { return get().modelIds[id] || ''; @@ -241,21 +243,37 @@ const useChatState = create<{ // LLM で推論する形式に extraData を変換する const extraData: ExtraData[] | undefined = m.extraData?.flatMap( (data) => { - // 推論する際は"data:image/png..." のといった情報は必要ないため、削除する - const base64EncodedData = uploadedFiles - ?.find((uploadedFile) => uploadedFile.s3Url === data.source.data) - ?.base64EncodedData?.replace(/^data:(.*,)?/, ''); - - // Base64 エンコードされた画像情報を設定する - return { - type: data.type, - name: data.name, - source: { - type: 'base64', - mediaType: data.source.mediaType, - data: base64EncodedData ?? '', - }, - }; + if (data.type === 'video') { + // Send S3 location for video + // https:// 形式の S3 URL から s3:// 形式の S3 URI に変換する + const s3Uri = getS3Uri(data.source.data ?? ''); + return { + type: data.type, + name: data.name, + source: { + type: 's3', + mediaType: data.source.mediaType, + data: s3Uri, + }, + }; + } else { + // Otherwise (image and file) send base64 encoded data + // 推論する際は"data:image/png..." のといった情報は必要ないため、削除する + const base64EncodedData = uploadedFiles + ?.find((uploadedFile) => uploadedFile.s3Url === data.source.data) + ?.base64EncodedData?.replace(/^data:(.*,)?/, ''); + + // Base64 エンコードされた画像情報を設定する + return { + type: data.type, + name: data.name, + source: { + type: 'base64', + mediaType: data.source.mediaType, + data: base64EncodedData ?? '', + }, + }; + } } ); return { diff --git a/packages/web/src/hooks/useFileApi.ts b/packages/web/src/hooks/useFileApi.ts index da832462..ca2ae315 100644 --- a/packages/web/src/hooks/useFileApi.ts +++ b/packages/web/src/hooks/useFileApi.ts @@ -11,6 +11,29 @@ import axios from 'axios'; const useFileApi = () => { const http = useHttp(); + const parseS3Url = (s3Url: string) => { + let result = /^s3:\/\/(?.+?)\/(?.+)/.exec(s3Url); + + if (!result) { + result = + /^https:\/\/s3.(?.+?).amazonaws.com\/(?.+?)\/(?.+)$/.exec( + s3Url + ); + + if (!result) { + result = + /^https:\/\/(?.+?).s3(|(\.|-)(?.+?)).amazonaws.com\/(?.+)$/.exec( + s3Url + ); + } + } + + return result?.groups as { + bucketName: string; + prefix: string; + region?: string; + }; + }; return { getSignedUrl: (req: GetFileUploadSignedUrlRequest) => { return http.post('file/url', req); @@ -23,36 +46,16 @@ const useFileApi = () => { data: req.file, }); }, - getFileDownloadSignedUrl: async (s3Uri: string) => { - let result = /^s3:\/\/(?.+?)\/(?.+)/.exec(s3Uri); + getFileDownloadSignedUrl: async (s3Url: string) => { + const { bucketName, prefix, region } = parseS3Url(s3Url); - if (!result) { - result = - /^https:\/\/s3.(?.+?).amazonaws.com\/(?.+?)\/(?.+)$/.exec( - s3Uri - ); - - if (!result) { - result = - /^https:\/\/(?.+?).s3(|(\.|-)(?.+?)).amazonaws.com\/(?.+)$/.exec( - s3Uri - ); - } - } - - const groups = result?.groups as { - bucketName: string; - prefix: string; - region?: string; - }; - - const [filePrefix, anchorLink] = groups.prefix.split('#'); + const [filePrefix, anchorLink] = prefix.split('#'); // Signed URL を取得 const params: GetFileDownloadSignedUrlRequest = { - bucketName: groups.bucketName, + bucketName: bucketName, filePrefix: decodeURIComponent(filePrefix), - region: groups.region, + region: region, }; const { data: url } = await http.api.get('/file/url', { @@ -63,6 +66,10 @@ const useFileApi = () => { deleteUploadedFile: async (fileName: string) => { return http.delete(`file/${fileName}`); }, + getS3Uri: (s3Url: string) => { + const { bucketName, prefix } = parseS3Url(s3Url); + return `s3://${bucketName}/${prefix}`; + }, }; }; diff --git a/packages/web/src/pages/ChatPage.tsx b/packages/web/src/pages/ChatPage.tsx index c613438d..d2baef31 100644 --- a/packages/web/src/pages/ChatPage.tsx +++ b/packages/web/src/pages/ChatPage.tsx @@ -48,7 +48,7 @@ const fileLimit: FileLimit = { maxImageFileCount: 20, maxImageFileSizeMB: 3.75, maxVideoFileCount: 1, - maxVideoFileSizeMB: 25, // 25 MB for base64 input (TODO: up to 1 GB through S3) + maxVideoFileSizeMB: 1000, // 1 GB for S3 input }; type StateType = {