From 3b0394e6fd75a1efc49bd8f4ac3c3bbe88d4d57b Mon Sep 17 00:00:00 2001 From: Chen188 Date: Tue, 13 Aug 2024 14:35:19 +0000 Subject: [PATCH] * adapting code with TEN * pass transcribe and polly init param when invoking start api; * update transcribe_asr graph to display chat in playground; * other code improvements. --- agents/property.json | 108 +++++++++- agents/property.json.example | 194 ++++++++++++++++-- .../bedrock_llm_extension.py | 72 +++---- .../polly_tts/polly_tts_extension.py | 2 + .../extension/polly_tts/polly_wrapper.py | 42 +++- .../transcribe_asr_extension.py | 29 ++- .../transcribe_wrapper.py | 37 +++- playground/src/common/constant.ts | 4 + playground/src/common/graph.ts | 10 + server/internal/config.go | 2 +- 10 files changed, 410 insertions(+), 90 deletions(-) diff --git a/agents/property.json b/agents/property.json index 05b8ba86..d2b78a4c 100644 --- a/agents/property.json +++ b/agents/property.json @@ -441,7 +441,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1008,7 +1008,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1036,6 +1036,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1067,6 +1073,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1082,6 +1092,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1124,6 +1158,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -1158,7 +1207,7 @@ "remote_stream_id": 123, "subscribe_audio": true, "publish_audio": true, - "publish_data": false, + "publish_data": true, "enable_agora_asr": false, "agora_asr_vendor_name": "microsoft", "agora_asr_language": "en-US", @@ -1189,7 +1238,7 @@ "region": "us-east-1", "access_key": "$AWS_ACCESS_KEY_ID", "secret_key": "$AWS_SECRET_ACCESS_KEY", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1217,6 +1266,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1297,6 +1352,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1312,6 +1371,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1354,6 +1437,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -2375,4 +2473,4 @@ } ] } -} +} \ No newline at end of file diff --git a/agents/property.json.example b/agents/property.json.example index d6245847..90fdf8a3 100644 --- a/agents/property.json.example +++ b/agents/property.json.example @@ -499,6 +499,12 @@ "azure_synthesis_voice_name": "en-US-JaneNeural" } }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "chat_transcriber", + "name": "chat_transcriber" + }, { "type": "extension", "extension_group": "default", @@ -519,6 +525,11 @@ "type": "extension_group", "addon": "default_extension_group", "name": "tts" + }, + { + "type": "extension_group", + "addon": "default_extension_group", + "name": "transcriber" } ], "connections": [ @@ -536,6 +547,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "chat_transcriber" } ] } @@ -551,6 +566,30 @@ { "extension_group": "tts", "extension": "azure_tts" + }, + { + "extension_group": "transcriber", + "extension": "chat_transcriber", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -593,6 +632,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "chat_transcriber", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -1063,7 +1117,7 @@ "addon": "agora_rtc", "name": "agora_rtc", "property": { - "app_id": "", + "app_id": "$AGORA_APP_ID", "token": "", "channel": "astra_agents_test", "stream_id": 1234, @@ -1074,8 +1128,8 @@ "enable_agora_asr": true, "agora_asr_vendor_name": "microsoft", "agora_asr_language": "en-US", - "agora_asr_vendor_key": "", - "agora_asr_vendor_region": "", + "agora_asr_vendor_key": "$AZURE_STT_KEY", + "agora_asr_vendor_region": "$AZURE_STT_REGION", "agora_asr_session_control_file_path": "session_control.conf" } }, @@ -1086,9 +1140,9 @@ "name": "bedrock_llm", "property": { "region": "us-east-1", - "access_key": "", - "secret_key": "", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "access_key": "$AWS_ACCESS_KEY_ID", + "secret_key": "$AWS_SECRET_ACCESS_KEY", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1102,8 +1156,8 @@ "name": "polly_tts", "property": { "region": "us-east-1", - "access_key": "", - "secret_key": "", + "access_key": "$AWS_ACCESS_KEY_ID", + "secret_key": "$AWS_SECRET_ACCESS_KEY", "engine": "generative", "voice": "Ruth", "sample_rate": "16000", @@ -1116,6 +1170,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1147,6 +1207,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1162,6 +1226,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1204,6 +1292,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -1231,19 +1334,19 @@ "addon": "agora_rtc", "name": "agora_rtc", "property": { - "app_id": "baf537f77ebc4187a06a1628a1827f14", + "app_id": "$AGORA_APP_ID", "token": "", "channel": "astra_agents_test", "stream_id": 1234, "remote_stream_id": 123, "subscribe_audio": true, "publish_audio": true, - "publish_data": false, + "publish_data": true, "enable_agora_asr": false, "agora_asr_vendor_name": "microsoft", "agora_asr_language": "en-US", - "agora_asr_vendor_key": "", - "agora_asr_vendor_region": "", + "agora_asr_vendor_key": "$AZURE_STT_KEY", + "agora_asr_vendor_region": "$AZURE_STT_REGION", "agora_asr_session_control_file_path": "session_control.conf" } }, @@ -1254,8 +1357,8 @@ "name": "transcribe_asr", "property": { "region": "us-east-1", - "access_key": "", - "secret_key": "", + "access_key": "$AWS_ACCESS_KEY_ID", + "secret_key": "$AWS_SECRET_ACCESS_KEY", "sample_rate": "16000", "lang_code": "en-US" } @@ -1267,9 +1370,9 @@ "name": "bedrock_llm", "property": { "region": "us-east-1", - "access_key": "", - "secret_key": "", - "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "access_key": "$AWS_ACCESS_KEY_ID", + "secret_key": "$AWS_SECRET_ACCESS_KEY", + "model": "$AWS_BEDROCK_MODEL", "max_tokens": 512, "prompt": "", "greeting": "ASTRA agent connected. How can i help you today?", @@ -1283,8 +1386,8 @@ "name": "polly_tts", "property": { "region": "us-east-1", - "access_key": "", - "secret_key": "", + "access_key": "$AWS_ACCESS_KEY_ID", + "secret_key": "$AWS_SECRET_ACCESS_KEY", "engine": "generative", "voice": "Ruth", "sample_rate": "16000", @@ -1297,6 +1400,12 @@ "addon": "interrupt_detector_python", "name": "interrupt_detector" }, + { + "type": "extension", + "extension_group": "transcriber", + "addon": "message_collector", + "name": "message_collector" + }, { "type": "extension_group", "addon": "default_extension_group", @@ -1377,6 +1486,10 @@ { "extension_group": "bedrock", "extension": "bedrock_llm" + }, + { + "extension_group": "transcriber", + "extension": "message_collector" } ] } @@ -1392,6 +1505,30 @@ { "extension_group": "tts", "extension": "polly_tts" + }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "cmd_conversions": [ + { + "cmd": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "is_final", + "type": "fixed_value", + "value": "bool(true)" + }, + { + "path": "stream_id", + "type": "fixed_value", + "value": "uint32(999)" + } + ] + } + } + ] } ] } @@ -1434,6 +1571,21 @@ } ] }, + { + "extension_group": "transcriber", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, { "extension_group": "default", "extension": "interrupt_detector", @@ -1499,7 +1651,7 @@ "prompt": "", "proxy_url": "", "greeting": "Astra已连接,需要我为您提供什么帮助?", - "checking_vision_text_items":"[\"让我看看你的摄像头...\",\"让我看一下...\",\"我看一下,请稍候...\"]", + "checking_vision_text_items": "[\"让我看看你的摄像头...\",\"让我看一下...\",\"我看一下,请稍候...\"]", "max_memory_length": 10, "enable_tools": true } @@ -1733,7 +1885,7 @@ "prompt": "", "proxy_url": "", "greeting": "ASTRA agent connected. How can i help you today?", - "checking_vision_text_items":"[\"Let me take a look...\",\"Let me check your camera...\",\"Please wait for a second...\"]", + "checking_vision_text_items": "[\"Let me take a look...\",\"Let me check your camera...\",\"Please wait for a second...\"]", "max_memory_length": 10, "enable_tools": true } diff --git a/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py b/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py index 5e7918a8..ecd7d528 100644 --- a/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py +++ b/agents/ten_packages/extension/bedrock_llm_python/bedrock_llm_extension.py @@ -136,20 +136,27 @@ def on_start(self, ten: TenEnv) -> None: # Send greeting if available if greeting: - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True - ) - ten.send_data(output_data) - logger.info(f"greeting [{greeting}] sent") - except Exception as err: - logger.info(f"greeting [{greeting}] send failed, err: {err}") + logger.info(f'sending greeting: [{greeting}]') + self.send_data(ten=ten, sentence=greeting, end_of_segment=True, input_text='greeting') + ten.on_start_done() + def send_data(self, ten, sentence, end_of_segment, input_text): + try: + output_data = Data.create("text_data") + output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence) + output_data.set_property_bool( + DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment + ) + ten.send_data(output_data) + logger.info( + f"for input text: [{input_text}] {'end of segment ' if end_of_segment else ''}sent sentence [{sentence}]" + ) + except Exception as err: + logger.exception( + f"for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}" + ) + def on_stop(self, ten: TenEnv) -> None: logger.info("BedrockLLMExtension on_stop") ten.on_stop_done() @@ -294,24 +301,12 @@ def converse_stream_worker(start_time, input_text, memory): ) # send sentence - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False - ) - ten.send_data(output_data) - logger.info( - f"GetConverseStream recv for input text: [{input_text}] sent sentence [{sentence}]" - ) - except Exception as err: - logger.info( - f"GetConverseStream recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}" - ) - break - + self.send_data( + ten=ten, + sentence=sentence, + end_of_segment=False, + input_text=input_text, + ) sentence = "" if not first_sentence_sent: first_sentence_sent = True @@ -335,22 +330,7 @@ def converse_stream_worker(start_time, input_text, memory): return # send end of segment - try: - output_data = Data.create("text_data") - output_data.set_property_string( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence - ) - output_data.set_property_bool( - DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True - ) - ten.send_data(output_data) - logger.info( - f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] sent" - ) - except Exception as err: - logger.info( - f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {err}" - ) + self.send_data(ten=ten, sentence=sentence, end_of_segment=True, input_text=input_text) except Exception as e: logger.info( diff --git a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py b/agents/ten_packages/extension/polly_tts/polly_tts_extension.py index b41ad5fa..c7defb4d 100644 --- a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py +++ b/agents/ten_packages/extension/polly_tts/polly_tts_extension.py @@ -63,6 +63,8 @@ def on_start(self, ten: TenEnv) -> None: f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {polly_config.__getattribute__(optional_param)}" ) + polly_config.validate() + self.polly = PollyWrapper(polly_config) self.frame_size = int( int(polly_config.sample_rate) diff --git a/agents/ten_packages/extension/polly_tts/polly_wrapper.py b/agents/ten_packages/extension/polly_tts/polly_wrapper.py index 0f1759bf..987b5b8b 100644 --- a/agents/ten_packages/extension/polly_tts/polly_wrapper.py +++ b/agents/ten_packages/extension/polly_tts/polly_wrapper.py @@ -7,6 +7,29 @@ from .log import logger +ENGINE_STANDARD = 'standard' +ENGINE_NEURAL = 'neural' +ENGINE_GENERATIVE = 'generative' +ENGINE_LONG_FORM = 'long-form' + +VOICE_ENGINE_MAP = { + "Zhiyu": [ENGINE_STANDARD, ENGINE_NEURAL], + "Matthew": [ENGINE_GENERATIVE, ENGINE_NEURAL], + "Ruth": [ENGINE_GENERATIVE, ENGINE_NEURAL, ENGINE_LONG_FORM] +} + +VOICE_LANG_MAP = { + "Zhiyu": ['cmn-CN'], + "Matthew": ['en-US'], + "Ruth": ['en-US'] +} + +LANGCODE_MAP = { + 'cmn-CN': 'cmn-CN', + 'zh-CN': 'cmn-CN', + 'en-US': 'en-US' +} + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/polly/client/synthesize_speech.html class PollyConfig: def __init__(self, @@ -30,6 +53,23 @@ def __init__(self, self.audio_format = 'pcm' # 'json'|'mp3'|'ogg_vorbis'|'pcm' self.include_visemes = False + def validate(self): + if not self.voice in set(VOICE_ENGINE_MAP.keys()): + err_msg = f"Invalid voice '{self.voice}'. Must be one of {list(VOICE_ENGINE_MAP.keys())}." + logger.error(err_msg) + raise ValueError(err_msg) + + if not self.engine in VOICE_ENGINE_MAP[self.voice]: + logger.warn(f"Invalid engine '{self.engine}' for voice '{self.voice}'. Must be one of {VOICE_ENGINE_MAP[self.voice]}. Fallback to {VOICE_ENGINE_MAP[self.voice][0]}") + self.engine = VOICE_ENGINE_MAP[self.voice][0] + + if self.lang_code: + self.lang_code = LANGCODE_MAP.get(self.lang_code, self.lang_code) + + if not self.lang_code in VOICE_LANG_MAP[self.voice]: + logger.warn(f"Invalid language code '{self.lang_code}' for voice '{self.voice}'. Must be one of {VOICE_LANG_MAP[self.voice]}. Fallback to {VOICE_LANG_MAP[self.voice][0]}") + self.lang_code = VOICE_LANG_MAP[self.voice][0] + @classmethod def default_config(cls): return cls( @@ -172,4 +212,4 @@ def get_voices(self, engine, language_code): vo["Name"]: vo["Id"] for vo in self.voice_metadata if engine in vo["SupportedEngines"] and language_code == vo["LanguageCode"] - } \ No newline at end of file + } diff --git a/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py b/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py index fb0307ee..1ebc83ed 100644 --- a/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py +++ b/agents/ten_packages/extension/transcribe_asr_python/transcribe_asr_extension.py @@ -7,6 +7,7 @@ CmdResult, ) +import json import asyncio import threading @@ -63,24 +64,30 @@ def on_start(self, ten: TenEnv) -> None: ten.on_start_done() - def put_pcm_frame(self, pcm_frame: AudioFrame) -> None: + def put_audio_frame(self, pcm_frame: AudioFrame) -> None: + if self.loop.is_closed(): + logger.warning("Event loop is closed, cannot enqueue frame") + return + try: asyncio.run_coroutine_threadsafe( self.queue.put(pcm_frame), self.loop ).result(timeout=0.1) except asyncio.QueueFull: logger.exception("Queue is full, dropping frame") + except asyncio.TimeoutError: + logger.warning("Timeout while putting frame in queue") except Exception as e: logger.exception(f"Error putting frame in queue: {e}") def on_audio_frame(self, ten: TenEnv, frame: AudioFrame) -> None: - self.put_pcm_frame(pcm_frame=frame) + self.put_audio_frame(pcm_frame=frame) def on_stop(self, ten: TenEnv) -> None: logger.info("TranscribeAsrExtension on_stop") # put an empty frame to stop transcribe_wrapper - self.put_pcm_frame(None) + self.put_audio_frame(None) self.stopped = True self.thread.join() self.loop.stop() @@ -92,10 +99,16 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: logger.info("TranscribeAsrExtension on_cmd") cmd_json = cmd.to_json() logger.info("TranscribeAsrExtension on_cmd json: " + cmd_json) + try: + cmd_json = json.loads(cmd_json) - cmdName = cmd.get_name() - logger.info("got cmd %s" % cmdName) + cmdName = cmd.get_name() + logger.info("got cmd %s" % cmdName) + if cmdName == "on_user_joined": + self.transcribe.set_user_id(cmd_json.get('user_id', '0'), cmd_json.get('remote_user_id', '0')) - cmd_result = CmdResult.create(StatusCode.OK) - cmd_result.set_property_string("detail", "success") - ten.return_result(cmd_result, cmd) + cmd_result = CmdResult.create(StatusCode.OK) + cmd_result.set_property_string("detail", "success") + ten.return_result(cmd_result, cmd) + except Exception as e: + logger.exception(f"Error handling cmd: {e}") diff --git a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py index c6b500ba..f69a7cc0 100644 --- a/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py +++ b/agents/ten_packages/extension/transcribe_asr_python/transcribe_wrapper.py @@ -16,13 +16,8 @@ DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text" DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final" - -def create_and_send_data(ten: TenEnv, text_result: str, is_final: bool): - stable_data = Data.create("text_data") - stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final) - stable_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result) - ten.send_data(stable_data) - +DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id" +DATA_OUT_TEXT_DATA_PROPERTY_EOS = "end_of_segment" class AsyncTranscribeWrapper(): def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, loop: asyncio.BaseEventLoop): @@ -51,6 +46,11 @@ def __init__(self, config: TranscribeConfig, queue: asyncio.Queue, ten:TenEnv, l asyncio.set_event_loop(self.loop) self.reset_stream() + def set_user_id(self, user_id:str="0", remote_user_id:str="0"): + logger.info(f"set_user_id: {user_id}, {remote_user_id}") + self.user_id = user_id + self.remote_user_id = remote_user_id + def reset_stream(self): self.stream = None self.handler = None @@ -71,6 +71,7 @@ async def create_stream(self) -> bool: try: self.stream = await self.get_transcribe_stream() self.handler = TranscribeEventHandler(self.stream.output_stream, self.ten) + self.handler.set_user_id(self.user_id, self.remote_user_id) self.event_handler_task = asyncio.create_task(self.handler.handle_events()) except Exception as e: logger.exception(e) @@ -143,15 +144,20 @@ def __init__(self, transcript_result_stream: TranscriptResultStream, ten: TenEnv super().__init__(transcript_result_stream) self.ten = ten + self.user_id = 0 + self.remote_user_id = 0 + async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> None: results = transcript_event.transcript.results text_result = "" is_final = True + end_of_segment = True for result in results: if result.is_partial: is_final = False + end_of_segment = False # continue for alt in result.alternatives: @@ -162,4 +168,19 @@ async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> No logger.info(f"got transcript: [{text_result}], is_final: [{is_final}]") - create_and_send_data(ten=self.ten, text_result=text_result, is_final=is_final) + self.create_and_send_data(text_result=text_result, is_final=is_final, end_of_segment=end_of_segment) + + def set_user_id(self, user_id:str="0", remote_user_id:str="0"): + self.user_id = int(user_id) + self.remote_user_id = int(remote_user_id) + + def create_and_send_data(self, text_result: str, is_final: bool, end_of_segment: bool): + stable_data = Data.create("text_data") + try: + stable_data.set_property_int(DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, self.remote_user_id) + stable_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result) + stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final) + stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_EOS, end_of_segment) + self.ten.send_data(stable_data) + except Exception as e: + logger.exception(e) \ No newline at end of file diff --git a/playground/src/common/constant.ts b/playground/src/common/constant.ts index 2d8f809b..fc7199aa 100644 --- a/playground/src/common/constant.ts +++ b/playground/src/common/constant.ts @@ -32,6 +32,10 @@ export const GRAPH_OPTIONS: GraphOptionItem[] = [ label: "Voice Agent with Knowledge - RAG + Qwen LLM + Cosy TTS", value: "va.qwen.rag" }, + { + label: "Voice Agent - Transcribe + Bedrock + Polly", + value: "va.transcribe-bedrock.polly" + }, ] export const isRagGraph = (graphName: string) => { diff --git a/playground/src/common/graph.ts b/playground/src/common/graph.ts index 9540412f..3b64ea91 100644 --- a/playground/src/common/graph.ts +++ b/playground/src/common/graph.ts @@ -41,6 +41,16 @@ export const getGraphProperties = (graphName: string, language: string, voiceTyp "azure_synthesis_voice_name": voiceNameMap[language]["azure"][voiceType] } } + } else if (graphName == "va.transcribe-bedrock.polly") { + return { + "transcribe_asr": { + "lang_code": language, + }, + "polly_tts": { + "voice": voiceNameMap[language]["polly"][voiceType], + "lang_code": language + } + } } return {} } \ No newline at end of file diff --git a/server/internal/config.go b/server/internal/config.go index 2701f212..2d754d99 100644 --- a/server/internal/config.go +++ b/server/internal/config.go @@ -53,4 +53,4 @@ var ( {ExtensionName: extensionNameHttpServer, Property: "listen_port"}, }, } -) +) \ No newline at end of file