Skip to content

Commit

Permalink
dictionaries again
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Aug 25, 2024
1 parent e2327a9 commit 590f686
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 20 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ This can be overridden by setting `BABYLON_BUILD_SOURCE` to `ON`.
#include "babylon.h"

int main() {
babylon_g2p_init("path/to/deep_phonemizer.onnx", "en_us", 1);
babylon_g2p_options_t options = {
.language = "en_us",
.use_dictionaries = 1,
.use_punctuation = 1,
};

babylon_g2p_init("path/to/deep_phonemizer.onnx", options);

const char* text = "Hello World";

Expand Down
12 changes: 9 additions & 3 deletions example/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# Minimum CMake version needed to configure the example programs.
cmake_minimum_required(VERSION 3.18)

project(babylon_example)
project(example_cpp)

# C++ example: built from main.cpp and linked against babylon.
add_executable(babylon_example main.cpp)
add_executable(example_cpp main.cpp)

target_link_libraries(babylon_example babylon)
target_link_libraries(example_cpp babylon)

project(example_c)

# C example: built from main.c and linked against babylon.
add_executable(example_c main.c)

target_link_libraries(example_c babylon)
25 changes: 25 additions & 0 deletions example/main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "babylon.h"

/*
 * Example program for the babylon C API: phonemizes and synthesizes the text
 * given as the first command-line argument.
 *
 * Returns 1 when no text argument is supplied or when the phonemizer fails to
 * initialize; otherwise runs the pipeline and returns 0.
 */
int main(int argc, char** argv) {
    /* The text to speak is required as argv[1]. */
    if (argc < 2) {
        return 1;
    }

    babylon_g2p_options_t options = {
        .language = "en_us",
        .use_dictionaries = 1,
        .use_punctuation = 1
    };

    /*
     * babylon_g2p_init returns 0 on success. Stop here on failure instead of
     * running TTS on top of a phonemizer that was never created.
     */
    if (babylon_g2p_init("./models/deep_phonemizer.onnx", options) != 0) {
        return 1;
    }

    babylon_tts_init("./models/curie.onnx");

    babylon_tts(argv[1], "path/to/output.wav");

    /* Release the sessions in reverse order of creation. */
    babylon_tts_free();

    babylon_g2p_free();

    return 0;
}
2 changes: 1 addition & 1 deletion example/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ int main(int argc, char** argv) {

text = argv[1];

DeepPhonemizer::Session dp(dp_model_path, "en_us", true);
DeepPhonemizer::Session dp(dp_model_path, "en_us", true, true);

Vits::Session vits(vits_model_path);

Expand Down
8 changes: 7 additions & 1 deletion include/babylon.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ extern "C" {
#define BABYLON_EXPORT __attribute__((visibility("default"))) __attribute__((used))
#endif

BABYLON_EXPORT int babylon_g2p_init(const char* model_path, const char* language, int use_punctuation);
typedef struct {
const char* language;
const unsigned char use_dictionaries;
const unsigned char use_punctuation;
} babylon_g2p_options_t;

BABYLON_EXPORT int babylon_g2p_init(const char* model_path, babylon_g2p_options_t options);

BABYLON_EXPORT char* babylon_g2p(const char* text);

Expand Down
8 changes: 5 additions & 3 deletions include/babylon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@ namespace DeepPhonemizer {

// Phonemizer session wrapping an ONNX Runtime model together with the text
// and phoneme tokenizers used around it.
class Session {
public:
Session(const std::string& model_path, const std::string language = "en_us", const bool use_punctuation = false);
// Loads the model at model_path. `language` selects the language used for
// tokenization; `use_dictionaries` enables the dictionary lookup that is
// tried before model inference; `use_punctuation` enables punctuation handling.
Session(const std::string& model_path, const std::string language = "en_us", const bool use_dictionaries = true, const bool use_punctuation = false);
~Session();

// Converts text to its phonemes, returned as strings.
std::vector<std::string> g2p(const std::string& text);
// Converts text to its phonemes, returned as phoneme token ids.
std::vector<int64_t> g2p_tokens(const std::string& text);

private:
std::string lang;
bool punctuation;
std::string language;
bool use_dictionaries;
bool use_punctuation;
Ort::Session* session;
SequenceTokenizer* text_tokenizer;
SequenceTokenizer* phoneme_tokenizer;
// Per-language dictionaries: language -> (word -> phoneme strings).
std::unordered_map<std::string, std::unordered_map<std::string, std::vector<std::string>>> dictionaries;

std::vector<int64_t> g2p_tokens_internal(const std::string& text);
};
Expand Down
Binary file modified models/deep_phonemizer.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion scripts/deep_phonemizer/dp_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def forward(self, text, phonemes=None, start_index=None):

for language in preprocessing['languages']:
language_dict = phoneme_dict[language]
language_dict_str = "\n".join(f"{key}\t{value}" for key, value in language_dict.items())
language_dict_str = "\n".join(f"{key}\t{' '.join(value)}" for key, value in language_dict.items())
metadata[f"{language}_dictionary"] = language_dict_str

print(metadata.keys())
Expand Down
4 changes: 2 additions & 2 deletions src/babylon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ static DeepPhonemizer::Session* dp;
static Vits::Session* vits;

extern "C" {
BABYLON_EXPORT int babylon_g2p_init(const char* model_path, const char* language, int use_punctuation) {
BABYLON_EXPORT int babylon_g2p_init(const char* model_path, babylon_g2p_options_t options) {
try {
dp = new DeepPhonemizer::Session(model_path, language, use_punctuation);
dp = new DeepPhonemizer::Session(model_path, options.language, options.use_dictionaries, options.use_punctuation);
return 0;
}
catch (const std::exception& e) {
Expand Down
57 changes: 49 additions & 8 deletions src/phonemizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ std::vector<float> softmax(const std::vector<float>& logits) {
return probabilities;
}

/**
 * Parses a serialized pronunciation dictionary into a word -> phonemes map.
 *
 * Each input line holds one entry: the word followed by its phonemes, all
 * separated by whitespace (tabs or spaces). Lines that contain no word
 * (empty or whitespace-only) are skipped so they never create a bogus ""
 * entry. If a word appears on several lines, the last occurrence wins.
 *
 * @param dictonary_str Newline-separated dictionary text.
 * @return Map from each word to its sequence of phoneme strings.
 */
std::unordered_map<std::string, std::vector<std::string>> process_dictionary(const std::string& dictonary_str) {
    std::unordered_map<std::string, std::vector<std::string>> dictionary;
    std::istringstream dictionary_stream(dictonary_str);

    std::string line;
    while (std::getline(dictionary_stream, line)) {
        std::istringstream line_stream(line);

        // A line without any token has no word to key on; ignore it.
        std::string word;
        if (!(line_stream >> word)) {
            continue;
        }

        // Fill the map entry in place to avoid copying a temporary vector.
        // Clearing first keeps "last entry wins" for repeated words.
        std::vector<std::string>& phonemes = dictionary[word];
        phonemes.clear();

        std::string phoneme;
        while (line_stream >> phoneme) {
            phonemes.push_back(phoneme);
        }
    }

    return dictionary;
}

namespace DeepPhonemizer {
SequenceTokenizer::SequenceTokenizer(const std::vector<std::string>& symbols, const std::vector<std::string>& languages, int char_repeats, bool lowercase, bool append_start_end)
: char_repeats(char_repeats), lowercase(lowercase), append_start_end(append_start_end), pad_token(" "), end_token("<end>") {
Expand Down Expand Up @@ -140,15 +160,15 @@ namespace DeepPhonemizer {
return -1;
}

Session::Session(const std::string& model_path, const std::string language, const bool use_punctuation) {
Session::Session(const std::string& model_path, const std::string language, const bool use_dictionaries, const bool use_punctuation) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DeepPhonemizer");
env.DisableTelemetryEvents();

Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

session = new Ort::Session(env, (const ORTCHAR_T *) model_path.c_str(), session_options);
this->session = new Ort::Session(env, (const ORTCHAR_T *) model_path.c_str(), session_options);

// Load metadata from the model
Ort::ModelMetadata model_metadata = session->GetModelMetadata();
Expand Down Expand Up @@ -181,6 +201,14 @@ namespace DeepPhonemizer {
phoneme_symbols.push_back(phoneme_symbol_buffer);
}

if (use_dictionaries) {
for (const auto& lang : languages) {
std::string key = lang + "_dictionary";
std::string dictonary_str = model_metadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get();
this->dictionaries[lang] = process_dictionary(dictonary_str);
}
}

int char_repeats = model_metadata.LookupCustomMetadataMapAllocated("char_repeats", allocator).get()[0] - '0';

bool lowercase = model_metadata.LookupCustomMetadataMapAllocated("lowercase", allocator).get()[0] == '1';
Expand All @@ -189,10 +217,11 @@ namespace DeepPhonemizer {
throw std::runtime_error("Language not supported.");
}

lang = language;
punctuation = use_punctuation;
text_tokenizer = new SequenceTokenizer(text_symbols, languages, char_repeats, lowercase);
phoneme_tokenizer = new SequenceTokenizer(phoneme_symbols, languages, 1, false);
this->language = language;
this->use_dictionaries = use_dictionaries;
this->use_punctuation = use_punctuation;
this->text_tokenizer = new SequenceTokenizer(text_symbols, languages, char_repeats, lowercase);
this->phoneme_tokenizer = new SequenceTokenizer(phoneme_symbols, languages, 1, false);
}

Session::~Session() {
Expand Down Expand Up @@ -222,7 +251,7 @@ namespace DeepPhonemizer {

phoneme_ids.insert(phoneme_ids.end(), cleaned_word_phoneme_ids.begin(), cleaned_word_phoneme_ids.end());

if (punctuation) {
if (use_punctuation) {
auto back_token = phoneme_tokenizer->get_token(std::string(1, word.back()));

// Check if the word ends with punctuation
Expand All @@ -244,9 +273,21 @@ namespace DeepPhonemizer {

key_text.erase(std::remove_if(key_text.begin(), key_text.end(), ::ispunct), key_text.end());

// First check if word is in the dictionary
if (dictionaries[language].count(key_text) && use_dictionaries) {
auto token_str = dictionaries[language].at(key_text);

std::vector<int64_t> tokens;
for (const auto& token : token_str) {
tokens.push_back(phoneme_tokenizer->get_token(token));
}

return tokens;
}

// Convert input text to tensor
std::vector<Ort::Value> input_tensors;
std::vector<int64_t> input_ids = text_tokenizer->operator()(text, lang);
std::vector<int64_t> input_ids = text_tokenizer->operator()(text, language);

std::vector<int64_t> input_shape = {1, static_cast<int64_t>(input_ids.size())};
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Expand Down

0 comments on commit 590f686

Please sign in to comment.