Subword regularization with SentencePiece (#27)
* Add support for subword regularization

* Add test case for 1-best

* Update CMakeLists.txt (refer to monsieurzhang/sentencepiece@d1d6efd)

* Revert changes to CMakeLists.txt

* Remove JSON dependency and add a SP-specific constructor

* Revise SP-specific constructor

* Allow nbest_size to be negative

* Align alpha type with the SP interface

* Do not break compilation with older SentencePiece installation

* Update client and Python bindings

* Update changelog
monsieurzhang authored and guillaumekln committed Aug 9, 2018
1 parent b799e7c commit 83f1737
Showing 12 changed files with 133 additions and 45 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

### New features

* Support SentencePiece sampling API

### Fixes and improvements

## [v1.6.1](https://github.com/OpenNMT/Tokenizer/releases/tag/v1.6.1) (2018-07-31)
5 changes: 5 additions & 0 deletions CMakeLists.txt
Expand Up @@ -63,6 +63,11 @@ else()
list(APPEND SOURCES src/SentencePiece.cc)
list(APPEND INCLUDE_DIRECTORIES ${SP_INCLUDE_DIR})
list(APPEND LINK_LIBRARIES ${SP_LIBRARY})

file(STRINGS ${SP_INCLUDE_DIR}/sentencepiece_processor.h HAS_SAMPLE_ENCODE REGEX "SampleEncode")
if(HAS_SAMPLE_ENCODE)
add_definitions(-DSP_HAS_SAMPLE_ENCODE)
endif()
endif()

add_library(${PROJECT_NAME} ${SOURCES})
27 changes: 13 additions & 14 deletions bindings/python/Python.cc
@@ -1,4 +1,4 @@
#define BOOST_PYTHON_MAX_ARITY 19
#define BOOST_PYTHON_MAX_ARITY 21
#include <boost/python.hpp>
#include <boost/python/stl_iterator.hpp>

@@ -33,6 +33,8 @@ class TokenizerWrapper
const std::string& bpe_vocab_path,
int bpe_vocab_threshold,
const std::string& sp_model_path,
int sp_nbest_size,
float sp_alpha,
const std::string& joiner,
bool joiner_annotate,
bool joiner_new,
@@ -71,19 +73,14 @@ class TokenizerWrapper
if (segment_alphabet_change)
flags |= onmt::Tokenizer::Flags::SegmentAlphabetChange;

std::string model_path;
onmt::Tokenizer::Mode tok_mode = onmt::Tokenizer::mapMode.at(mode);

if (!bpe_model_path.empty())
model_path = bpe_model_path;
else if (!sp_model_path.empty())
{
flags |= onmt::Tokenizer::Flags::SentencePieceModel;
model_path = sp_model_path;
}

_tokenizer = new onmt::Tokenizer(onmt::Tokenizer::mapMode.at(mode),
flags, model_path, joiner,
bpe_vocab_path, bpe_vocab_threshold);
if (!sp_model_path.empty())
_tokenizer = new onmt::Tokenizer(sp_model_path, sp_nbest_size, sp_alpha,
tok_mode, flags, joiner);
else
_tokenizer = new onmt::Tokenizer(tok_mode, flags, bpe_model_path, joiner,
bpe_vocab_path, bpe_vocab_threshold);

for (auto it = py::stl_input_iterator<std::string>(segment_alphabet);
it != py::stl_input_iterator<std::string>(); it++)
@@ -139,11 +136,13 @@ BOOST_PYTHON_MODULE(tokenizer)
{
py::class_<TokenizerWrapper>(
"Tokenizer",
py::init<std::string, std::string, std::string, int, std::string, std::string, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, py::list>(
py::init<std::string, std::string, std::string, int, std::string, int, float, std::string, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, bool, py::list>(
(py::arg("bpe_model_path")="",
py::arg("bpe_vocab_path")="",
py::arg("bpe_vocab_threshold")=50,
py::arg("sp_model_path")="",
py::arg("sp_nbest_size")=0,
py::arg("sp_alpha")=0.1,
py::arg("joiner")=onmt::Tokenizer::joiner_marker,
py::arg("joiner_annotate")=false,
py::arg("joiner_new")=false,
2 changes: 2 additions & 0 deletions bindings/python/README.md
@@ -17,6 +17,8 @@ tokenizer = pyonmt.Tokenizer(
bpe_vocab_path="",
bpe_vocab_threshold=50,
sp_model_path="",
sp_nbest_size=0,
sp_alpha=0.1,
joiner="",
joiner_annotate=False,
joiner_new=False,
36 changes: 19 additions & 17 deletions cli/tokenize.cc
Expand Up @@ -30,6 +30,8 @@ int main(int argc, char* argv[])
("bpe_vocab_threshold", po::value<int>()->default_value(50), "Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV.")
#ifdef WITH_SP
("sp_model,sp", po::value<std::string>()->default_value(""), "path to the SentencePiece model")
("sp_nbest_size", po::value<int>()->default_value(0), "number of candidates for the SentencePiece sampling API")
("sp_alpha", po::value<float>()->default_value(0.1), "smoothing parameter for the SentencePiece sampling API")
#endif
;

@@ -70,28 +72,28 @@ int main(int argc, char* argv[])
vm["segment_alphabet"].as<std::string>(),
boost::is_any_of(","));

std::string model_path;
onmt::Tokenizer* tokenizer = nullptr;

if (!vm["bpe_model"].as<std::string>().empty())
model_path = vm["bpe_model"].as<std::string>();
#ifdef WITH_SP
else if (!vm["sp_model"].as<std::string>().empty())
if (!vm["sp_model"].as<std::string>().empty())
{
flags |= onmt::Tokenizer::Flags::SentencePieceModel;
model_path = vm["sp_model"].as<std::string>();
tokenizer = new onmt::Tokenizer(vm["sp_model"].as<std::string>(),
vm["sp_nbest_size"].as<int>(),
vm["sp_alpha"].as<float>(),
onmt::Tokenizer::mapMode.at(vm["mode"].as<std::string>()),
flags,
vm["joiner"].as<std::string>());
}
else
#endif

std::string bpe_vocab_path = vm["bpe_vocab"].as<std::string>();
int bpe_vocab_threshold = vm["bpe_vocab_threshold"].as<int>();

onmt::Tokenizer* tokenizer = new onmt::Tokenizer(
onmt::Tokenizer::mapMode.at(vm["mode"].as<std::string>()),
flags,
model_path,
vm["joiner"].as<std::string>(),
bpe_vocab_path,
bpe_vocab_threshold);
{
tokenizer = new onmt::Tokenizer(onmt::Tokenizer::mapMode.at(vm["mode"].as<std::string>()),
flags,
vm["bpe_model"].as<std::string>(),
vm["joiner"].as<std::string>(),
vm["bpe_vocab"].as<std::string>(),
vm["bpe_vocab_threshold"].as<int>());
}

for (const auto& alphabet : alphabets_to_segment)
{
8 changes: 8 additions & 0 deletions docs/options.md
Expand Up @@ -74,6 +74,14 @@ When using `bpe_vocab`, any words with a frequency lower than `bpe_vocab_thresho

Path to the SentencePiece model. To replicate `spm_encode`, the tokenization mode should be `none`.

### `sp_nbest_size` (int, default: `0`)

Number of candidates for the SentencePiece sampling API. When the value is 0, the standard SentencePiece encoding is used; a negative value samples from all tokenization hypotheses.

### `sp_alpha` (float, default: `0.1`)

Smoothing parameter for the SentencePiece sampling API.

## Marking joint tokens

These options inject characters to make the tokenization reversible.
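As an illustration of the sampling options documented above (`sp_nbest_size`, `sp_alpha`), here is a minimal hypothetical sketch using the SentencePiece-specific constructor added by this commit (the model path and input sentence are illustrative). With a non-zero `sp_nbest_size`, repeated calls may return different segmentations of the same sentence:

#include <onmt/Tokenizer.h>

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Hypothetical model path; sp_nbest_size=-1 samples from all candidate
  // segmentations and sp_alpha=0.1 is the smoothing parameter.
  onmt::Tokenizer tokenizer("sp.model", /*sp_nbest_size=*/-1, /*sp_alpha=*/0.1);

  for (int i = 0; i < 3; ++i)
  {
    std::vector<std::string> tokens;
    tokenizer.tokenize("Subword regularization helps.", tokens);

    // With sampling enabled, each pass may print a different segmentation.
    for (const auto& token : tokens)
      std::cout << token << ' ';
    std::cout << '\n';
  }
  return 0;
}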
4 changes: 4 additions & 0 deletions include/onmt/SentencePiece.h
Expand Up @@ -14,11 +14,15 @@ namespace onmt
public:
SentencePiece(const std::string& model_path);

void enable_regularization(int nbest_size, float alpha);

std::vector<std::string> encode(const std::string& str) const override;
std::vector<AnnotatedToken> encode_and_annotate(const AnnotatedToken& token) const override;

private:
sentencepiece::SentencePieceProcessor _processor;
int _nbest_size;
float _alpha;
};

}
10 changes: 10 additions & 0 deletions include/onmt/Tokenizer.h
Expand Up @@ -55,6 +55,15 @@ namespace onmt
const std::string& joiner = joiner_marker,
const std::string& bpe_vocab_path = "",
int bpe_vocab_threshold = 50);

// SentencePiece-specific constructor.
Tokenizer(const std::string& sp_model_path,
int sp_nbest_size = 0,
float sp_alpha = 0.1,
Mode mode = Mode::None,
int flags = Flags::None,
const std::string& joiner = joiner_marker);

~Tokenizer();

void tokenize(const std::string& text,
@@ -107,6 +116,7 @@ namespace onmt

std::set<int> _segment_alphabet;

void read_flags(int flags);
std::vector<AnnotatedToken> encode_subword(const std::vector<AnnotatedToken>& tokens) const;
void finalize_tokens(std::vector<AnnotatedToken>& annotated_tokens,
std::vector<std::string>& tokens) const;
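For reference, a minimal sketch of the SentencePiece-specific constructor declared above, with the mode and flags arguments spelled out (the model path is hypothetical, and the `tokenize(text, tokens)` overload is assumed from the existing ITokenizer interface):

#include <onmt/Tokenizer.h>

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Hypothetical model: sample among the 64 best segmentations with
  // smoothing 0.1, and annotate subwords with the joiner marker.
  onmt::Tokenizer tokenizer("wmtende.model",
                            /*sp_nbest_size=*/64,
                            /*sp_alpha=*/0.1,
                            onmt::Tokenizer::Mode::None,
                            onmt::Tokenizer::Flags::JoinerAnnotate);

  std::vector<std::string> tokens;
  tokenizer.tokenize("Hello world!", tokens);

  for (const auto& token : tokens)
    std::cout << token << '\n';
  return 0;
}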
19 changes: 18 additions & 1 deletion src/SentencePiece.cc
Expand Up @@ -6,14 +6,31 @@ namespace onmt
static const std::string sp_marker("▁");

SentencePiece::SentencePiece(const std::string& model_path)
: _nbest_size(0)
, _alpha(0.0)
{
_processor.Load(model_path);
}

void SentencePiece::enable_regularization(int nbest_size, float alpha)
{
_nbest_size = nbest_size;
_alpha = alpha;
}

std::vector<std::string> SentencePiece::encode(const std::string& str) const
{
std::vector<std::string> pieces;
_processor.Encode(str, &pieces);

#ifdef SP_HAS_SAMPLE_ENCODE
if (_nbest_size != 0)
_processor.SampleEncode(str, _nbest_size, _alpha, &pieces);
else
#endif
{
_processor.Encode(str, &pieces);
}

return pieces;
}

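For context, the guarded SampleEncode call above can be exercised directly against the SentencePieceProcessor API along these lines (a minimal sketch; the model path is hypothetical, and error checking on Load is omitted):

#include <sentencepiece_processor.h>

#include <iostream>
#include <string>
#include <vector>

int main()
{
  sentencepiece::SentencePieceProcessor processor;
  processor.Load("sp.model");  // hypothetical model path

  std::vector<std::string> pieces;
  // Draw one sampled segmentation: nbest_size=-1 samples from all
  // hypotheses, alpha=0.1 smooths the sampling distribution.
  processor.SampleEncode("Hello world!", -1, 0.1f, &pieces);

  for (const auto& piece : pieces)
    std::cout << piece << '\n';
  return 0;
}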
55 changes: 42 additions & 13 deletions src/Tokenizer.cc
Expand Up @@ -68,22 +68,10 @@ namespace onmt
const std::string& bpe_vocab_path,
int bpe_vocab_threshold)
: _mode(mode)
, _case_feature(flags & Flags::CaseFeature)
, _joiner_annotate(flags & Flags::JoinerAnnotate)
, _joiner_new(flags & Flags::JoinerNew)
, _with_separators(flags & Flags::WithSeparators)
, _segment_case(flags & Flags::SegmentCase)
, _segment_numbers(flags & Flags::SegmentNumbers)
, _segment_alphabet_change(flags & Flags::SegmentAlphabetChange)
, _cache_model((flags & Flags::CacheBPEModel) | (flags & Flags::CacheModel))
, _no_substitution(flags & Flags::NoSubstitution)
, _spacer_annotate(flags & Flags::SpacerAnnotate)
, _spacer_new(flags & Flags::SpacerNew)
, _preserve_placeholders(flags & Flags::PreservePlaceholders)
, _preserve_segmented_tokens(flags & Flags::PreserveSegmentedTokens)
, _subword_encoder(nullptr)
, _joiner(joiner)
{
read_flags(flags);
if (flags & Flags::SentencePieceModel)
#ifdef WITH_SP
set_sp_model(model_path, _cache_model);
@@ -101,6 +89,47 @@ namespace onmt
}
}

Tokenizer::Tokenizer(const std::string& sp_model_path,
int sp_nbest_size,
float sp_alpha,
Mode mode,
int flags,
const std::string& joiner)
: _mode(mode)
, _subword_encoder(nullptr)
, _joiner(joiner)
{
#ifndef WITH_SP
throw std::runtime_error("The Tokenizer was not built with SentencePiece support");
#else
read_flags(flags);
set_sp_model(sp_model_path, _cache_model);
if (sp_nbest_size != 0)
# ifdef SP_HAS_SAMPLE_ENCODE
((SentencePiece*)_subword_encoder)->enable_regularization(sp_nbest_size, sp_alpha);
# else
throw std::runtime_error("This version of SentencePiece does not include the sampling API");
# endif
#endif
}

void Tokenizer::read_flags(int flags)
{
_case_feature = flags & Flags::CaseFeature;
_joiner_annotate = flags & Flags::JoinerAnnotate;
_joiner_new = flags & Flags::JoinerNew;
_with_separators = flags & Flags::WithSeparators;
_segment_case = flags & Flags::SegmentCase;
_segment_numbers = flags & Flags::SegmentNumbers;
_segment_alphabet_change = flags & Flags::SegmentAlphabetChange;
_cache_model = (flags & Flags::CacheBPEModel) | (flags & Flags::CacheModel);
_no_substitution = flags & Flags::NoSubstitution;
_spacer_annotate = flags & Flags::SpacerAnnotate;
_spacer_new = flags & Flags::SpacerNew;
_preserve_placeholders = flags & Flags::PreservePlaceholders;
_preserve_segmented_tokens = flags & Flags::PreserveSegmentedTokens;
}

Tokenizer::~Tokenizer()
{
if (!_cache_model)
Binary file added test/data/sp-models/sp_regularization.model
10 changes: 10 additions & 0 deletions test/test.cc
@@ -414,6 +414,16 @@ TEST(TokenizerTest, SentencePiece) {
"The ▁two ▁shows , ▁called ▁De si re ▁and ▁S e c re t s , ▁will ▁be ▁one - hour ▁prime - time ▁shows .");
}

#ifdef SP_HAS_SAMPLE_ENCODE
TEST(TokenizerTest, SentencePieceSubwordRegularization) {
auto tokenizer = std::unique_ptr<ITokenizer>(
new Tokenizer(get_data("sp-models/sp_regularization.model"), 1, 0.1));
test_tok_and_detok(tokenizer,
"The two shows, called Desire and Secrets, will be one-hour prime-time shows.",
"▁The ▁ two ▁show s , ▁call ed ▁De si re ▁ and ▁Sec re t s , ▁w ill ▁be ▁one - h our ▁ pri me - t im e ▁show s .");
}
#endif

TEST(TokenizerTest, SentencePieceAlt) {
auto tokenizer = std::unique_ptr<ITokenizer>(
new Tokenizer(Tokenizer::Mode::None, Tokenizer::Flags::SentencePieceModel,
