Add WordPieceDecoder tests that match Tokenizers (#139)
shavit authored Dec 12, 2024
1 parent 5751308 commit 2f611bf
Showing 2 changed files with 30 additions and 17 deletions.
28 changes: 11 additions & 17 deletions Sources/Tokenizers/Decoder.swift
@@ -57,34 +57,28 @@ class WordPieceDecoder: Decoder {
     let prefix: String
     let cleanup: Bool
 
+    // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31
+    private let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'|n't|'m|'s|'ve|'re)", options: [])
+
     required public init(config: Config) {
         guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
         self.prefix = prefix
         self.cleanup = config.cleanup?.boolValue ?? false
     }
 
     func decode(tokens: [String]) -> [String] {
-        var newTokens = [String]()
-        newTokens.reserveCapacity(tokens.count)
-        for (index, token) in tokens.enumerated() {
-            var decodedToken = token
-            if index != 0 {
-                if decodedToken.hasPrefix(prefix) {
-                    decodedToken = String(decodedToken.dropFirst(prefix.count))
-                } else {
-                    decodedToken = " \(decodedToken)"
-                }
-            }
-            if cleanup {
-                decodedToken = cleanUpTokenization(decodedToken)
-            }
-            newTokens.append(decodedToken)
+        let firstToken = cleanup ? cleanUpTokenization(tokens.first!) : tokens.first!
+        return [firstToken] + tokens.dropFirst().map { token in
+            let token = token.hasPrefix(prefix) ? token.replacingCharacters(in: token.range(of: prefix)!, with: "") : " \(token)"
+            return cleanup ? cleanUpTokenization(token) : token
         }
-        return newTokens
     }
 
+    // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40
     private func cleanUpTokenization(_ token: String) -> String {
-        return token.trimmingCharacters(in: .whitespacesAndNewlines)
+        let range = NSRange(location: 0, length: token.utf16.count)
+        return re.stringByReplacingMatches(in: token, options: [], range: range, withTemplate: "$1")
+            .replacingOccurrences(of: " do not", with: " don't")
     }
 }

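For reference, below is a minimal standalone sketch of the cleanup step introduced above. It is not part of the commit: the names cleanupRe and cleanUp and the sample strings are illustrative. The regex removes the space that decode inserts before punctuation and common English contraction suffixes, mirroring the Rust WordPiece decoder linked in the code comments.

import Foundation

// Illustrative sketch, not part of the commit: same pattern as the new `re` property above.
let cleanupRe = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'|n't|'m|'s|'ve|'re)", options: [])

// Drops the space before punctuation/contraction suffixes, as cleanUpTokenization does per token.
func cleanUp(_ text: String) -> String {
    let range = NSRange(location: 0, length: text.utf16.count)
    return cleanupRe.stringByReplacingMatches(in: text, options: [], range: range, withTemplate: "$1")
        .replacingOccurrences(of: " do not", with: " don't")
}

print(cleanUp("who do n't"))     // who don't
print(cleanUp("hello world .")) // hello world.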
19 changes: 19 additions & 0 deletions Tests/TokenizersTests/DecoderTests.swift
@@ -24,4 +24,23 @@ class DecoderTests: XCTestCase {
             ["Hey", " my", " friend", " ", " <s>", " how", " are", " you"]
         )
     }
+
+    func testWordPieceDecoder() {
+        let config = Config(["prefix": "##", "cleanup": true])
+        let decoder = WordPieceDecoder(config: config)
+
+        let testCases: [([String], String)] = [
+            (["##inter", "##national", "##ization"], "##internationalization"),
+            (["##auto", "##mat", "##ic", "transmission"], "##automatic transmission"),
+            (["who", "do", "##n't", "does", "n't", "can't"], "who don't doesn't can't"),
+            (["##un", "##believ", "##able", "##fa", "##ntastic"], "##unbelievablefantastic"),
+            (["this", "is", "un", "##believ", "##able", "fa", "##ntastic"], "this is unbelievable fantastic"),
+            (["The", "##quick", "##brown", "fox"], "Thequickbrown fox"),
+        ]
+
+        for (tokens, expected) in testCases {
+            let output = decoder.decode(tokens: tokens)
+            XCTAssertEqual(output.joined(), expected)
+        }
+    }
 }
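The new decoder can be exercised outside XCTest in the same way the test does. A hypothetical snippet, assuming WordPieceDecoder and Config are visible to the calling module (the test target uses them directly):

import Tokenizers

let config = Config(["prefix": "##", "cleanup": true])
let decoder = WordPieceDecoder(config: config)

// Continuation pieces ("##...") are merged into the previous word; other tokens get a leading space.
let pieces = decoder.decode(tokens: ["this", "is", "un", "##believ", "##able"])
print(pieces.joined())  // this is unbelievable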
