Add a Vocab class and related functions (#281)
* Add a Vocab class and related functions

* Add export directive for MSVC

* Define __call__ method for batch lookup

* Fix typo in README

* Improve formatting

* Remove currently unused static methods

* Ensure special tokens are not removed by a resize

* Update README

* Rename variable for consistency

* Allow overriding the ID returned for OOV tokens

* Improve typing in help output

* Use maxint instead of -1 for default value
guillaumekln authored Mar 7, 2022
1 parent fee8c9a commit 45fc569
Showing 8 changed files with 458 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,12 +9,14 @@ The project follows [semantic versioning 2.0.0](https://semver.org/). The API co
* `onmt::SentencePiece`
* `onmt::SpaceTokenizer`
* `onmt::Tokenizer`
* `onmt::Vocab`
* `onmt::unicode::*`
* Python
* `pyonmttok.BPELearner`
* `pyonmttok.SentencePieceLearner`
* `pyonmttok.SentencePieceTokenizer`
* `pyonmttok.Tokenizer`
* `pyonmttok.Vocab`

---

2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -74,6 +74,7 @@ set(PUBLIC_HEADERS
include/onmt/SubwordEncoder.h
include/onmt/SubwordLearner.h
include/onmt/Tokenizer.h
include/onmt/Vocab.h
)

set(SOURCES
@@ -88,6 +89,7 @@ set(SOURCES
src/Token.cc
src/Tokenizer.cc
src/Utils.cc
src/Vocab.cc
src/unicode/Unicode.cc
)

73 changes: 73 additions & 0 deletions bindings/python/README.md
@@ -18,6 +18,7 @@ pip install pyonmttok

1. [Tokenization](#tokenization)
1. [Subword learning](#subword-learning)
1. [Vocabulary](#vocabulary)
1. [Token API](#token-api)
1. [Utilities](#utilities)

@@ -214,6 +215,78 @@ learner.ingest_token(token: Union[str, pyonmttok.Token])
learner.learn(model_path: str, verbose: bool = False) -> pyonmttok.Tokenizer
```

## Vocabulary

### Example

```python
tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)

with open("train.txt") as train_file:
vocab = pyonmttok.build_vocab_from_lines(
train_file,
tokenizer=tokenizer,
maximum_size=32000,
special_tokens=["<blank>", "<unk>", "<s>", "</s>"],
)

with open("vocab.txt", "w") as vocab_file:
for token in vocab.ids_to_tokens:
vocab_file.write("%s\n" % token)
```

### Interface

```python
# Special tokens are added with ids 0, 1, etc., and are never removed by a resize.
vocab = pyonmttok.Vocab(special_tokens: Optional[List[str]] = None)

# Read-only properties.
vocab.tokens_to_ids -> Dict[str, int]
vocab.ids_to_tokens -> List[str]

# Get or set the ID returned for out-of-vocabulary tokens.
# By default, it is the ID of the token <unk> if present in the vocabulary, len(vocab) otherwise.
vocab.default_id -> int

vocab.lookup_token(token: str) -> int
vocab.lookup_index(index: int) -> str

# Calls lookup_token on a batch of tokens.
vocab.__call__(tokens: List[str]) -> List[int]

vocab.__len__() -> int # Implements: len(vocab)
vocab.__contains__(token: str) -> bool # Implements: "hello" in vocab
vocab.__getitem__(token: str) -> int # Implements: vocab["hello"]

# Add tokens to the vocabulary after tokenization.
# If a tokenizer is not set, the text is split on spaces.
vocab.add_from_text(text: str, tokenizer: Optional[pyonmttok.Tokenizer] = None) -> None
vocab.add_from_file(path: str, tokenizer: Optional[pyonmttok.Tokenizer] = None) -> None
vocab.add_token(token: str) -> None

vocab.resize(maximum_size: int = 0, minimum_frequency: int = 1) -> None


# Build a vocabulary from an iterator of lines.
# If a tokenizer is not set, the lines are split on spaces.
pyonmttok.build_vocab_from_lines(
    lines: Iterable[str],
    tokenizer: Optional[pyonmttok.Tokenizer] = None,
    maximum_size: int = 0,
    minimum_frequency: int = 1,
    special_tokens: Optional[List[str]] = None,
) -> pyonmttok.Vocab

# Build a vocabulary from an iterator of tokens.
pyonmttok.build_vocab_from_tokens(
    tokens: Iterable[str],
    maximum_size: int = 0,
    minimum_frequency: int = 1,
    special_tokens: Optional[List[str]] = None,
) -> pyonmttok.Vocab
```
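
For illustration, a minimal sketch of the lookup and out-of-vocabulary behavior documented above, using only the interface listed in this README:

```python
import pyonmttok

# "<unk>" is added first, so it receives ID 0 and becomes the default OOV ID.
vocab = pyonmttok.Vocab(special_tokens=["<unk>"])
vocab.add_token("hello")

assert vocab["hello"] == 1
assert vocab.default_id == 0
assert vocab.lookup_token("world") == 0      # OOV token maps to the <unk> ID
assert vocab(["hello", "world"]) == [1, 0]   # batch lookup via __call__
```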

## Token API

The Token API tokenizes text into `pyonmttok.Token` objects. It is useful for applying custom logic at the token level while retaining enough information to write the tokenization to disk or to detokenize later.
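
A minimal sketch of that workflow (assuming the `as_token_objects` argument of `tokenize` and the `detokenize` method from the pyonmttok API, which are not shown in this hunk):

```python
import pyonmttok

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)

# Tokenize into pyonmttok.Token objects instead of plain strings.
tokens = tokenizer.tokenize("Hello World!", as_token_objects=True)

# Each Token keeps its surface form and joiner information,
# so the original text can be reconstructed.
assert tokenizer.detokenize(tokens) == "Hello World!"
```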
73 changes: 73 additions & 0 deletions bindings/python/pyonmttok/Python.cc
@@ -11,6 +11,7 @@
#include <onmt/SentencePiece.h>
#include <onmt/BPELearner.h>
#include <onmt/SentencePieceLearner.h>
#include <onmt/Vocab.h>

namespace py = pybind11;
using namespace pybind11::literals;
@@ -481,6 +482,13 @@ static py::ssize_t hash_token(const onmt::Token& token) {
py::tuple(py::cast(token.features))));
}

static onmt::Vocab create_vocab(const std::optional<std::vector<std::string>>& special_tokens) {
  if (special_tokens)
    return onmt::Vocab(special_tokens.value());
  else
    return onmt::Vocab();
}

PYBIND11_MODULE(_ext, m)
{
m.def("is_placeholder", &onmt::Tokenizer::is_placeholder, py::arg("token"));
@@ -719,4 +727,69 @@ PYBIND11_MODULE(_ext, m)
py::arg("tokenizer")=py::none(),
py::arg("keep_vocab")=false)
;

  py::class_<onmt::Vocab>(m, "Vocab")
    .def(py::init(&create_vocab), py::arg("special_tokens")=py::none())
    .def("__len__", &onmt::Vocab::size)
    .def("__contains__", &onmt::Vocab::contains, py::arg("token"))
    .def("__getitem__", py::overload_cast<const std::string&>(&onmt::Vocab::lookup, py::const_),
         py::arg("token"))
    .def("lookup_token", py::overload_cast<const std::string&>(&onmt::Vocab::lookup, py::const_),
         py::arg("token"))
    .def("lookup_index", py::overload_cast<size_t>(&onmt::Vocab::lookup, py::const_),
         py::arg("index"))
    .def("__call__",
         [](const onmt::Vocab& vocab, const std::vector<std::string>& tokens) {
           std::vector<size_t> ids;
           ids.reserve(tokens.size());
           for (const auto& token : tokens)
             ids.emplace_back(vocab.lookup(token));
           return ids;
         },
         py::arg("tokens"),
         py::call_guard<py::gil_scoped_release>())

    .def_property("default_id", &onmt::Vocab::get_default_id, &onmt::Vocab::set_default_id)
    .def_property_readonly("tokens_to_ids", &onmt::Vocab::tokens_to_ids)
    .def_property_readonly("ids_to_tokens", &onmt::Vocab::ids_to_tokens)

    .def("add_token", &onmt::Vocab::add_token, py::arg("token"))

    .def("add_from_text",
         [](onmt::Vocab& vocab,
            const std::string& text,
            const std::optional<TokenizerWrapper>& tokenizer) {
           vocab.add_from_text(text, tokenizer ? tokenizer.value().get().get() : nullptr);
         },
         py::arg("text"),
         py::arg("tokenizer")=nullptr,
         py::call_guard<py::gil_scoped_release>())

    .def("add_from_file",
         [](onmt::Vocab& vocab,
            const std::string& path,
            const std::optional<TokenizerWrapper>& tokenizer) {
           std::ifstream in(path);
           if (!in)
             throw std::invalid_argument("Failed to open input file " + path);
           vocab.add_from_stream(in, tokenizer ? tokenizer.value().get().get() : nullptr);
         },
         py::arg("path"),
         py::arg("tokenizer")=nullptr,
         py::call_guard<py::gil_scoped_release>())

    .def("resize", &onmt::Vocab::resize,
         py::arg("maximum_size")=0,
         py::arg("minimum_frequency")=1,
         py::call_guard<py::gil_scoped_release>())

    .def("__copy__",
         [](const onmt::Vocab& vocab) {
           return onmt::Vocab(vocab);
         })
    .def("__deepcopy__",
         [](const onmt::Vocab& vocab, const py::object& dict) {
           return onmt::Vocab(vocab);
         })
    ;
}
28 changes: 28 additions & 0 deletions bindings/python/pyonmttok/__init__.py
@@ -27,7 +27,35 @@
    Token,
    Tokenizer,
    TokenType,
    Vocab,
    is_placeholder,
    set_random_seed,
)
from pyonmttok.version import __version__


def build_vocab_from_tokens(
    tokens,
    maximum_size=0,
    minimum_frequency=1,
    special_tokens=None,
):
    vocab = Vocab(special_tokens)
    for token in tokens:
        vocab.add_token(token)
    vocab.resize(maximum_size=maximum_size, minimum_frequency=minimum_frequency)
    return vocab


def build_vocab_from_lines(
    lines,
    tokenizer=None,
    maximum_size=0,
    minimum_frequency=1,
    special_tokens=None,
):
    vocab = Vocab(special_tokens)
    for line in lines:
        vocab.add_from_text(line.rstrip("\r\n"), tokenizer)
    vocab.resize(maximum_size=maximum_size, minimum_frequency=minimum_frequency)
    return vocab
93 changes: 93 additions & 0 deletions bindings/python/test/test.py
@@ -1,4 +1,5 @@
import copy
import itertools
import os
import pickle

@@ -527,3 +528,95 @@ def test_token_pickle():
    data = pickle.dumps(token)
    token2 = pickle.loads(data)
    assert token == token2


def test_vocab():
    special_tokens = ["<blank>", "<s>", "</s>"]
    vocab = pyonmttok.Vocab(special_tokens=special_tokens)
    vocab.add_token("a")
    vocab.add_token("a")
    vocab.add_token("b")

    assert len(vocab) == 5
    assert "a" in vocab
    assert "b" in vocab
    assert "c" not in vocab
    assert vocab["<blank>"] == 0
    assert vocab["a"] == 3
    assert vocab["b"] == 4
    assert vocab["c"] == len(vocab)
    assert vocab.lookup_index(len(vocab)) == "<unk>"
    assert vocab(["a", "b"]) == [3, 4]

    assert vocab.tokens_to_ids == {
        "<blank>": 0,
        "<s>": 1,
        "</s>": 2,
        "a": 3,
        "b": 4,
    }

    vocab1 = copy.deepcopy(vocab)
    vocab2 = copy.deepcopy(vocab)

    vocab1.resize(maximum_size=4)
    vocab2.resize(minimum_frequency=2)

    expected_tokens = list(special_tokens) + ["a"]
    assert vocab1.ids_to_tokens == expected_tokens
    assert vocab2.ids_to_tokens == expected_tokens

    vocab3 = copy.deepcopy(vocab)
    vocab3.resize(maximum_size=1)
    assert vocab3.ids_to_tokens == special_tokens


def test_vocab_from_text():
    vocab = pyonmttok.Vocab()
    vocab.add_from_text("Hello World!")
    assert vocab.ids_to_tokens == ["Hello", "World!"]

    tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
    vocab = pyonmttok.Vocab()
    vocab.add_from_text("Hello World!", tokenizer)
    assert vocab.ids_to_tokens == ["Hello", "World", "■!"]


def test_vocab_from_file(tmpdir):
    input_path = str(tmpdir.join("input.txt"))
    with open(input_path, "w", encoding="utf-8") as input_file:
        input_file.write("Hello World!\n")

    tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
    vocab = pyonmttok.Vocab()
    vocab.add_from_file(input_path, tokenizer)
    assert vocab.ids_to_tokens == ["Hello", "World", "■!"]


def test_vocab_build_helpers():
    tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
    lines = ["Hello World!", "Hello all."]
    tokens = itertools.chain.from_iterable(map(tokenizer, lines))

    vocab1 = pyonmttok.build_vocab_from_lines(lines, tokenizer)
    vocab2 = pyonmttok.build_vocab_from_tokens(tokens)

    expected_tokens = ["Hello", "World", "■!", "all", "■."]
    assert vocab1.ids_to_tokens == expected_tokens
    assert vocab2.ids_to_tokens == expected_tokens


@pytest.mark.parametrize(
    "tokens,default_id,expected_default_id",
    [
        (["a", "b", "c"], None, 3),
        (["a", "b", "c"], 1, 1),
        (["a", "<unk>", "b", "c"], None, 1),
    ],
)
def test_vocab_default_id(tokens, default_id, expected_default_id):
    vocab = pyonmttok.build_vocab_from_tokens(tokens)
    if default_id is not None:
        vocab.default_id = default_id
    assert vocab.default_id == expected_default_id
    assert vocab.lookup_token("oov") == expected_default_id