Skip to content

Commit

Permalink
* adapting code with TEN
Browse files Browse the repository at this point in the history
* pass transcribe and polly init param when invoking start api;
* update transcribe_asr graph to display chat in playground;
* other code improvements.
  • Loading branch information
Chen188 committed Aug 29, 2024
1 parent a1090a3 commit 1e08ef0
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 69 deletions.
108 changes: 103 additions & 5 deletions agents/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@
"region": "us-east-1",
"access_key": "$AWS_ACCESS_KEY_ID",
"secret_key": "$AWS_SECRET_ACCESS_KEY",
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"model": "$AWS_BEDROCK_MODEL",
"max_tokens": 512,
"prompt": "",
"greeting": "ASTRA agent connected. How can i help you today?",
Expand Down Expand Up @@ -1008,7 +1008,7 @@
"region": "us-east-1",
"access_key": "$AWS_ACCESS_KEY_ID",
"secret_key": "$AWS_SECRET_ACCESS_KEY",
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"model": "$AWS_BEDROCK_MODEL",
"max_tokens": 512,
"prompt": "",
"greeting": "ASTRA agent connected. How can i help you today?",
Expand Down Expand Up @@ -1036,6 +1036,12 @@
"addon": "interrupt_detector_python",
"name": "interrupt_detector"
},
{
"type": "extension",
"extension_group": "transcriber",
"addon": "message_collector",
"name": "message_collector"
},
{
"type": "extension_group",
"addon": "default_extension_group",
Expand Down Expand Up @@ -1067,6 +1073,10 @@
{
"extension_group": "bedrock",
"extension": "bedrock_llm"
},
{
"extension_group": "transcriber",
"extension": "message_collector"
}
]
}
Expand All @@ -1082,6 +1092,30 @@
{
"extension_group": "tts",
"extension": "polly_tts"
},
{
"extension_group": "transcriber",
"extension": "message_collector",
"cmd_conversions": [
{
"cmd": {
"type": "per_property",
"keep_original": true,
"rules": [
{
"path": "is_final",
"type": "fixed_value",
"value": "bool(true)"
},
{
"path": "stream_id",
"type": "fixed_value",
"value": "uint32(999)"
}
]
}
}
]
}
]
}
Expand Down Expand Up @@ -1124,6 +1158,21 @@
}
]
},
{
"extension_group": "transcriber",
"extension": "message_collector",
"data": [
{
"name": "data",
"dest": [
{
"extension_group": "default",
"extension": "agora_rtc"
}
]
}
]
},
{
"extension_group": "default",
"extension": "interrupt_detector",
Expand Down Expand Up @@ -1158,7 +1207,7 @@
"remote_stream_id": 123,
"subscribe_audio": true,
"publish_audio": true,
"publish_data": false,
"publish_data": true,
"enable_agora_asr": false,
"agora_asr_vendor_name": "microsoft",
"agora_asr_language": "en-US",
Expand Down Expand Up @@ -1189,7 +1238,7 @@
"region": "us-east-1",
"access_key": "$AWS_ACCESS_KEY_ID",
"secret_key": "$AWS_SECRET_ACCESS_KEY",
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"model": "$AWS_BEDROCK_MODEL",
"max_tokens": 512,
"prompt": "",
"greeting": "ASTRA agent connected. How can i help you today?",
Expand Down Expand Up @@ -1217,6 +1266,12 @@
"addon": "interrupt_detector_python",
"name": "interrupt_detector"
},
{
"type": "extension",
"extension_group": "transcriber",
"addon": "message_collector",
"name": "message_collector"
},
{
"type": "extension_group",
"addon": "default_extension_group",
Expand Down Expand Up @@ -1297,6 +1352,10 @@
{
"extension_group": "bedrock",
"extension": "bedrock_llm"
},
{
"extension_group": "transcriber",
"extension": "message_collector"
}
]
}
Expand All @@ -1312,6 +1371,30 @@
{
"extension_group": "tts",
"extension": "polly_tts"
},
{
"extension_group": "transcriber",
"extension": "message_collector",
"cmd_conversions": [
{
"cmd": {
"type": "per_property",
"keep_original": true,
"rules": [
{
"path": "is_final",
"type": "fixed_value",
"value": "bool(true)"
},
{
"path": "stream_id",
"type": "fixed_value",
"value": "uint32(999)"
}
]
}
}
]
}
]
}
Expand Down Expand Up @@ -1354,6 +1437,21 @@
}
]
},
{
"extension_group": "transcriber",
"extension": "message_collector",
"data": [
{
"name": "data",
"dest": [
{
"extension_group": "default",
"extension": "agora_rtc"
}
]
}
]
},
{
"extension_group": "default",
"extension": "interrupt_detector",
Expand Down Expand Up @@ -2161,4 +2259,4 @@
}
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,27 @@ def on_start(self, ten: TenEnv) -> None:

# Send greeting if available
if greeting:
try:
output_data = Data.create("text_data")
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting
)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True
)
ten.send_data(output_data)
logger.info(f"greeting [{greeting}] sent")
except Exception as err:
logger.info(f"greeting [{greeting}] send failed, err: {err}")
logger.info(f'sending greeting: [{greeting}]')
self.send_data(ten=ten, sentence=greeting, end_of_segment=True, input_text='greeting')

ten.on_start_done()

def send_data(self, ten, sentence, end_of_segment, input_text):
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
)
ten.send_data(output_data)
logger.info(
f"for input text: [{input_text}] {'end of segment ' if end_of_segment else ''}sent sentence [{sentence}]"
)
except Exception as err:
logger.exception(
f"for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}"
)

def on_stop(self, ten: TenEnv) -> None:
logger.info("BedrockLLMExtension on_stop")
ten.on_stop_done()
Expand Down Expand Up @@ -294,24 +301,12 @@ def converse_stream_worker(start_time, input_text, memory):
)

# send sentence
try:
output_data = Data.create("text_data")
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence
)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False
)
ten.send_data(output_data)
logger.info(
f"GetConverseStream recv for input text: [{input_text}] sent sentence [{sentence}]"
)
except Exception as err:
logger.info(
f"GetConverseStream recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {err}"
)
break

self.send_data(
ten=ten,
sentence=sentence,
end_of_segment=False,
input_text=input_text,
)
sentence = ""
if not first_sentence_sent:
first_sentence_sent = True
Expand All @@ -335,22 +330,7 @@ def converse_stream_worker(start_time, input_text, memory):
return

# send end of segment
try:
output_data = Data.create("text_data")
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence
)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True
)
ten.send_data(output_data)
logger.info(
f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] sent"
)
except Exception as err:
logger.info(
f"GetConverseStream for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {err}"
)
self.send_data(ten=ten, sentence=sentence, end_of_segment=True, input_text=input_text)

except Exception as e:
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def on_start(self, ten: TenEnv) -> None:
f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {polly_config.__getattribute__(optional_param)}"
)

polly_config.validate()

self.polly = PollyWrapper(polly_config)
self.frame_size = int(
int(polly_config.sample_rate)
Expand Down
42 changes: 41 additions & 1 deletion agents/ten_packages/extension/polly_tts/polly_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@

from .log import logger

ENGINE_STANDARD = 'standard'
ENGINE_NEURAL = 'neural'
ENGINE_GENERATIVE = 'generative'
ENGINE_LONG_FORM = 'long-form'

VOICE_ENGINE_MAP = {
"Zhiyu": [ENGINE_STANDARD, ENGINE_NEURAL],
"Matthew": [ENGINE_GENERATIVE, ENGINE_NEURAL],
"Ruth": [ENGINE_GENERATIVE, ENGINE_NEURAL, ENGINE_LONG_FORM]
}

VOICE_LANG_MAP = {
"Zhiyu": ['cmn-CN'],
"Matthew": ['en-US'],
"Ruth": ['en-US']
}

LANGCODE_MAP = {
'cmn-CN': 'cmn-CN',
'zh-CN': 'cmn-CN',
'en-US': 'en-US'
}

# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/polly/client/synthesize_speech.html
class PollyConfig:
def __init__(self,
Expand All @@ -30,6 +53,23 @@ def __init__(self,
self.audio_format = 'pcm' # 'json'|'mp3'|'ogg_vorbis'|'pcm'
self.include_visemes = False

def validate(self):
if not self.voice in set(VOICE_ENGINE_MAP.keys()):
err_msg = f"Invalid voice '{self.voice}'. Must be one of {list(VOICE_ENGINE_MAP.keys())}."
logger.error(err_msg)
raise ValueError(err_msg)

if not self.engine in VOICE_ENGINE_MAP[self.voice]:
logger.warn(f"Invalid engine '{self.engine}' for voice '{self.voice}'. Must be one of {VOICE_ENGINE_MAP[self.voice]}. Fallback to {VOICE_ENGINE_MAP[self.voice][0]}")
self.engine = VOICE_ENGINE_MAP[self.voice][0]

if self.lang_code:
self.lang_code = LANGCODE_MAP.get(self.lang_code, self.lang_code)

if not self.lang_code in VOICE_LANG_MAP[self.voice]:
logger.warn(f"Invalid language code '{self.lang_code}' for voice '{self.voice}'. Must be one of {VOICE_LANG_MAP[self.voice]}. Fallback to {VOICE_LANG_MAP[self.voice][0]}")
self.lang_code = VOICE_LANG_MAP[self.voice][0]

@classmethod
def default_config(cls):
return cls(
Expand Down Expand Up @@ -172,4 +212,4 @@ def get_voices(self, engine, language_code):
vo["Name"]: vo["Id"]
for vo in self.voice_metadata
if engine in vo["SupportedEngines"] and language_code == vo["LanguageCode"]
}
}
Loading

0 comments on commit 1e08ef0

Please sign in to comment.