From 1e08ef02ada789c8ebcd9d0d77b64fea830c3294 Mon Sep 17 00:00:00 2001 From: Chen188 Date: Tue, 13 Aug 2024 14:35:19 +0000 Subject: [PATCH 1/2] * adapting code with TEN * pass transcribe and polly init param when invoking start api; * update transcribe_asr graph to display chat in playground; * other code improvements. --- agents/property.json | 108 +++++++++++++++++- .../bedrock_llm_extension.py | 72 +++++------- .../polly_tts/polly_tts_extension.py | 2 + .../extension/polly_tts/polly_wrapper.py | 42 ++++++- .../transcribe_asr_extension.py | 29 +++-- .../transcribe_wrapper.py | 37 ++++-- playground/src/common/constant.ts | 4 + playground/src/common/graph.ts | 10 ++ server/internal/config.go | 2 +- 9 files changed, 237 insertions(+), 69 deletions(-) diff --git a/agents/property.json b/agents/property.json index 7e461ee3..cad63043 100644 --- a/agents/property.json +++ b/agents/property.json @@ -441,7 +441,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1008,7 +1008,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1036,6 +1036,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1067,6 +1073,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1082,6 +1092,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1124,6 +1158,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -1158,7 +1207,7 @@ "remote_stream_id": 123, "subscribe_audio": true, "publish_audio": true, - "publish_data": false, + "publish_data": true, "enable_agora_asr": false, "agora_asr_vendor_name": "microsoft", "agora_asr_language": "en-US", @@ -1189,7 +1238,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1217,6 +1266,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1297,6 +1352,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1312,6 +1371,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1354,6 +1437,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -2161,4 +2259,4 @@ } ] } -} +} \ No newline at end of file diff --git a/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py b/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py index 5e7918a8..ecd7d528 100644 --- a/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py +++ b/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py @@ -136,20 +136,27 @@ def on_start(self, ten: TenEnv) -> None: # Send greeting if available if greeting: - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True - ) - ten.send_data(output_data) - logger.info(f"greeting [{greeting}] sent") - except Exception as err: - logger.info(f"greeting [{greeting}] send failed, err: {err}") + logger.info(f'sending greeting: [{greeting}]') + self.send_data(ten=ten, sentence=greeting, end_of_segment=True, input_text='greeting') + ten.on_start_done() + def send_data(self, ten, sentence, end_of_segment, input_text): + try: + output_data = Data.create("text_data") + output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence) + output_data.set_property_bool( + DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment + ) + ten.send_data(output_data) + logger.info( + f"for input text: [{input_text}] {'end of segment ' if end_of_segment else ''}sent sentence [{sentence}]" + ) + except Exception as err: + logger.exception( + f"for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}" + ) + def on_stop(self, ten: TenEnv) -> None: logger.info("BedrockLLMExtension on_stop") ten.on_stop_done() @@ -294,24 +301,12 @@ def converse_stream_worker(start_time, input_text, memory): ) # send sentence - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False - ) - ten.send_data(output_data) - logger.info( - f"GetConverseStream recv for input text: [{input_text}] sent sentence [{sentence}]" - ) - except Exception as err: - logger.info( - f"GetConverseStream recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}" - ) - break - + self.send_data( + ten=ten, + sentence=sentence, + end_of_segment=False, + input_text=input_text, + ) sentence = "" if not first_sentence_sent: first_sentence_sent = True @@ -335,22 +330,7 @@ def converse_stream_worker(start_time, input_text, memory): return # send end of segment - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True - ) - ten.send_data(output_data) - logger.info( - f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] sent" - ) - except Exception as err: - logger.info( - f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {err}" - ) + self.send_data(ten=ten, sentence=sentence, end_of_segment=True, input_text=input_text) except Exception as e: logger.info( diff --git a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py b/agents/ten_packages/extension/polly_tts/polly_tts_extension.py index b41ad5fa..c7defb4d 100644 --- a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py +++ b/agents/ten_packages/extension/polly_tts/polly_tts_extension.py @@ -63,6 +63,8 @@ def on_start(self, ten: TenEnv) -> None: f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {polly_config.__getattribute__(optional_param)}" ) + polly_config.validate() + self.polly = PollyWrapper(polly_config) self.frame_size = int( int(polly_config.sample_rate) diff --git a/agents/ten_packages/extension/polly_tts/polly_wrapper.py b/agents/ten_packages/extension/polly_tts/polly_wrapper.py index 0f1759bf..987b5b8b 100644 --- a/agents/ten_packages/extension/polly_tts/polly_wrapper.py +++ b/agents/ten_packages/extension/polly_tts/polly_wrapper.py @@ -7,6 +7,29 @@ from .log import logger +ENGINE_STANDARD = 'standard' +ENGINE_NEURAL = 'neural' +ENGINE_GENERATIVE = 'generative' +ENGINE_LONG_FORM = 'long-form' + +VOICE_ENGINE_MAP = { + "Zhiyu": [ENGINE_STANDARD, ENGINE_NEURAL], + "Matthew": [ENGINE_GENERATIVE, ENGINE_NEURAL], + "Ruth": [ENGINE_GENERATIVE, ENGINE_NEURAL, ENGINE_LONG_FORM] +} + +VOICE_LANG_MAP = { + "Zhiyu": ['cmn-CN'], + "Matthew": ['en-US'], + "Ruth": ['en-US'] +} + +LANGCODE_MAP = { + 'cmn-CN': 'cmn-CN', + 'zh-CN': 'cmn-CN', + 'en-US': 'en-US' +} + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/polly/client/synthesize_speech.html class PollyConfig: def __init__(self, @@ -30,6 +53,23 @@ def __init__(self, self.audio_format = 'pcm' # 'json'|'mp3'|'ogg_vorbis'|'pcm' self.include_visemes = False + def validate(self): + if not self.voice in set(VOICE_ENGINE_MAP.keys()): + err_msg = f"Invalid voice '{self.voice}'. Must be one of {list(VOICE_ENGINE_MAP.keys())}." + logger.error(err_msg) + raise ValueError(err_msg) + + if not self.engine in VOICE_ENGINE_MAP[self.voice]: + logger.warn(f"Invalid engine '{self.engine}' for voice '{self.voice}'. Must be one of {VOICE_ENGINE_MAP[self.voice]}. Fallback to {VOICE_ENGINE_MAP[self.voice][0]}") + self.engine = VOICE_ENGINE_MAP[self.voice][0] + + if self.lang_code: + self.lang_code = LANGCODE_MAP.get(self.lang_code, self.lang_code) + + if not self.lang_code in VOICE_LANG_MAP[self.voice]: + logger.warn(f"Invalid language code '{self.lang_code}' for voice '{self.voice}'. Must be one of {VOICE_LANG_MAP[self.voice]}. Fallback to {VOICE_LANG_MAP[self.voice][0]}") + self.lang_code = VOICE_LANG_MAP[self.voice][0] + @classmethod def default_config(cls): return cls( @@ -172,4 +212,4 @@ def get_voices(self, engine, language_code): vo["Name"]: vo["Id"] for vo in self.voice_metadata if engine in vo["SupportedEngines"] and language_code == vo["LanguageCode"] - } \ No newline at end of file + } diff --git a/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py b/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py index fb0307ee..1ebc83ed 100644 --- a/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py +++ b/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py @@ -7,6 +7,7 @@ CmdResult, ) +import json import asyncio import threading @@ -63,24 +64,30 @@ def on_start(self, ten: TenEnv) -> None: ten.on_start_done() - def put_pcm_frame(self, pcm_frame: AudioFrame) -> None: + def put_audio_frame(self, pcm_frame: AudioFrame) -> None: + if self.loop.is_closed(): + logger.warning("Event loop is closed, cannot enqueue frame") + return + try: asyncio.run_coroutine_threadsafe( self.queue.put(pcm_frame), self.loop ).result(timeout=0.1) except asyncio.QueueFull: logger.exception("Queue is full, dropping frame") + except asyncio.TimeoutError: + logger.warning("Timeout while putting frame in queue") except Exception as e: logger.exception(f"Error putting frame in queue: {e}") def on_audio_frame(self, ten: TenEnv, frame: AudioFrame) -> None: - self.put_pcm_frame(pcm_frame=frame) + self.put_audio_frame(pcm_frame=frame) def on_stop(self, ten: TenEnv) -> None: logger.info("TranscribeAsrExtension on_stop") # put an empty frame to stop transcribe_wrapper - self.put_pcm_frame(None) + self.put_audio_frame(None) self.stopped = True self.thread.join() self.loop.stop() @@ -92,10 +99,16 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: logger.info("TranscribeAsrExtension on_cmd") cmd_json = cmd.to_json() logger.info("TranscribeAsrExtension on_cmd json: " + cmd_json) + try: + cmd_json = json.loads(cmd_json) - cmdName = cmd.get_name() - logger.info("got cmd %s" % cmdName) + cmdName = cmd.get_name() + logger.info("got cmd %s" % cmdName) + if cmdName == "on_user_joined": + self.transcribe.set_user_id(cmd_json.get('user_id', '0'), cmd_json.get('remote_user_id', '0')) - cmd_result = CmdResult.create(StatusCode.OK) - cmd_result.set_property_string("detail", "success") - ten.return_result(cmd_result, cmd) + cmd_result = CmdResult.create(StatusCode.OK) + cmd_result.set_property_string("detail", "success") + ten.return_result(cmd_result, cmd) + except Exception as e: + logger.exception(f"Error handling cmd: {e}") diff --git a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py index c6b500ba..f69a7cc0 100644 --- a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py +++ b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py @@ -16,13 +16,8 @@ DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text" DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final" - -def create_and_send_data(ten: TenEnv, text_result: str, is_final: bool): - stable_data = Data.create("text_data") - stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final) - stable_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result) - ten.send_data(stable_data) - +DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id" +DATA_OUT_TEXT_DATA_PROPERTY_EOS = "end_of_segment" class AsyncTranscribeWrapper(): def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, loop: asyncio.BaseEventLoop): @@ -51,6 +46,11 @@ def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, l asyncio.set_event_loop(self.loop) self.reset_stream() + def set_user_id(self, user_id:str="0", remote_user_id:str="0"): + logger.info(f"set_user_id: {user_id}, {remote_user_id}") + self.user_id = user_id + self.remote_user_id = remote_user_id + def reset_stream(self): self.stream = None self.handler = None @@ -71,6 +71,7 @@ async def create_stream(self) -> bool: try: self.stream = await self.get_transcribe_stream() self.handler = TranscribeEventHandler(self.stream.output_stream, self.ten) + self.handler.set_user_id(self.user_id, self.remote_user_id) self.event_handler_task = asyncio.create_task(self.handler.handle_events()) except Exception as e: logger.exception(e) @@ -143,15 +144,20 @@ def __init__(self, transcript_result_stream: TranscriptResultStream, ten: TenEnv super().__init__(transcript_result_stream) self.ten = ten + self.user_id = 0 + self.remote_user_id = 0 + async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> None: results = transcript_event.transcript.results text_result = "" is_final = True + end_of_segment = True for result in results: if result.is_partial: is_final = False + end_of_segment = False # continue for alt in result.alternatives: @@ -162,4 +168,19 @@ async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> No logger.info(f"got transcript: [{text_result}], is_final: [{is_final}]") - create_and_send_data(ten=self.ten, text_result=text_result, is_final=is_final) + self.create_and_send_data(text_result=text_result, is_final=is_final, end_of_segment=end_of_segment) + + def set_user_id(self, user_id:str="0", remote_user_id:str="0"): + self.user_id = int(user_id) + self.remote_user_id = int(remote_user_id) + + def create_and_send_data(self, text_result: str, is_final: bool, end_of_segment: bool): + stable_data = Data.create("text_data") + try: + stable_data.set_property_int(DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, self.remote_user_id) + stable_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result) + stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final) + stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_EOS, end_of_segment) + self.ten.send_data(stable_data) + except Exception as e: + logger.exception(e) \ No newline at end of file diff --git a/playground/src/common/constant.ts b/playground/src/common/constant.ts index a2b2db83..bdfb5747 100644 --- a/playground/src/common/constant.ts +++ b/playground/src/common/constant.ts @@ -40,6 +40,10 @@ export const GRAPH_OPTIONS: GraphOptionItem[] = [ label: "Voice Agent with Knowledge - RAG + Qwen LLM + Cosy TTS", value: "va.qwen.rag" }, + { + label: "Voice Agent - Transcribe + Bedrock + Polly", + value: "va.transcribe-bedrock.polly" + }, ] export const isRagGraph = (graphName: string) => { diff --git a/playground/src/common/graph.ts b/playground/src/common/graph.ts index 523ee14b..a302eecb 100644 --- a/playground/src/common/graph.ts +++ b/playground/src/common/graph.ts @@ -101,6 +101,16 @@ export const getGraphProperties = (graphName: string, language: string, voiceTyp "azure_synthesis_voice_name": voiceNameMap[language]["azure"][voiceType] } } + } else if (graphName == "va.transcribe-bedrock.polly") { + return { + "transcribe_asr": { + "lang_code": language, + }, + "polly_tts": { + "voice": voiceNameMap[language]["polly"][voiceType], + "lang_code": language + } + } } return {} } \ No newline at end of file diff --git a/server/internal/config.go b/server/internal/config.go index e0e2edb0..81198223 100644 --- a/server/internal/config.go +++ b/server/internal/config.go @@ -38,4 +38,4 @@ var ( {ExtensionName: extensionNameHttpServer, Property: "listen_port"}, }, } -) +) \ No newline at end of file From fe740c7e4d341c340608793540a26e361c423a08 Mon Sep 17 00:00:00 2001 From: Chen188 Date: Wed, 4 Sep 2024 03:41:49 +0000 Subject: [PATCH 2/2] fix: allow transcribe using CN regions --- .../transcribe_asr_python/transcribe_wrapper.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py index f69a7cc0..92298b61 100644 --- a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py +++ b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py @@ -7,6 +7,7 @@ ) from amazon_transcribe.auth import StaticCredentialResolver +from amazon_transcribe.endpoints import BaseEndpointResolver from amazon_transcribe.client import TranscribeStreamingClient from amazon_transcribe.handlers import TranscriptResultStreamHandler from amazon_transcribe.model import TranscriptEvent, TranscriptResultStream, StartStreamTranscriptionEventStream @@ -19,6 +20,14 @@ DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id" DATA_OUT_TEXT_DATA_PROPERTY_EOS = "end_of_segment" +class TranscribeCnEndpointResolver(BaseEndpointResolver): + def __init__(self, region: str): + self.region = region + + async def resolve(self, region: str) -> str: + """Apply region to transcribe uri template.""" + return f"https://transcribestreaming.{region}.amazonaws.com.cn" + class AsyncTranscribeWrapper(): def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, loop: asyncio.BaseEventLoop): self.queue = queue @@ -29,8 +38,14 @@ def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, l if config.access_key and config.secret_key: logger.info(f"init trascribe client with access key: {config.access_key}") + + endpoint_resolver = None + if config.region.startswith("cn-"): + endpoint_resolver = TranscribeCnEndpointResolver(config.region) + self.transcribe_client = TranscribeStreamingClient( region=config.region, + endpoint_resolver=endpoint_resolver, credential_resolver=StaticCredentialResolver( access_key_id=config.access_key, secret_access_key=config.secret_key