diff --git a/Sources/Tokenizers/Decoder.swift b/Sources/Tokenizers/Decoder.swift
index c98c4b9..67e8659 100644
--- a/Sources/Tokenizers/Decoder.swift
+++ b/Sources/Tokenizers/Decoder.swift
@@ -23,7 +23,7 @@ extension Decoder {
 
 enum DecoderType: String {
     case Sequence
-//    case WordPiece
+    case WordPiece
     case ByteLevel
     case Replace
     case ByteFallback
@@ -47,11 +47,54 @@ struct DecoderFactory {
         case .Fuse      : return FuseDecoder(config: config)
         case .Strip     : return StripDecoder(config: config)
         case .Metaspace : return MetaspaceDecoder(config: config)
+        case .WordPiece : return WordPieceDecoder(config: config)
         default         : fatalError("Unsupported Decoder type: \(typeName)")
         }
     }
 }
 
+class WordPieceDecoder: Decoder {
+    let prefix: String
+    let cleanup: Bool
+
+    required public init(config: Config) {
+        guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
+        self.prefix = prefix
+        self.cleanup = config.cleanup?.boolValue ?? false
+    }
+
+    func decode(tokens: [String]) -> [String] {
+        var newTokens = [String]()
+        newTokens.reserveCapacity(tokens.count)
+        for (index, token) in tokens.enumerated() {
+            var decodedToken = token
+            if index != 0 {
+                if decodedToken.hasPrefix(prefix) {
+                    decodedToken = String(decodedToken.dropFirst(prefix.count))
+                } else {
+                    decodedToken = " \(decodedToken)"
+                }
+            }
+            if cleanup {
+                decodedToken = cleanUpTokenization(decodedToken)
+            }
+            newTokens.append(decodedToken)
+        }
+        return newTokens
+    }
+
+    // Remove spaces before punctuation and common English contractions,
+    // mirroring `clean_up_tokenization` in `transformers`. A plain whitespace
+    // trim here would strip the word-separating spaces inserted above.
+    private func cleanUpTokenization(_ token: String) -> String {
+        return token
+            .replacingOccurrences(of: " ' ", with: "'")
+            .replacingOccurrences(of: #" (\.|\?|!|,|n't|'m|'s|'ve|'re)"#,
+                                  with: "$1",
+                                  options: .regularExpression)
+    }
+}
+
 class DecoderSequence: Decoder {
     let decoders: [Decoder]
 
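Note: a minimal usage sketch of the decoder above (illustrative only — the inline Config construction and values are hypothetical; real configurations come from a model's tokenizer.json, where the WordPiece continuation prefix is typically "##"):

    let decoder = WordPieceDecoder(config: Config(["prefix": "##", "cleanup": true]))
    let pieces = decoder.decode(tokens: ["Hey", "friend", "##s", "!"])
    // Non-initial tokens either fuse onto the previous word (prefix stripped)
    // or receive a leading space; cleanup then removes the space before "!".
    // pieces == ["Hey", " friend", "s", "!"], which joins to "Hey friends!"
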
"[^\\s\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]" + } + + func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { + return text.ranges(of: re).map { String(text[$0]) } + } +} + class PreTokenizerSequence: PreTokenizer { let preTokenizers: [PreTokenizer] @@ -184,11 +199,10 @@ class ByteLevelPreTokenizer: PreTokenizer { } class PunctuationPreTokenizer: PreTokenizer { - let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"# let re: String required init(config: Config) { - re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+" + re = "[^\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]+" } func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] { diff --git a/Sources/Tokenizers/Utils.swift b/Sources/Tokenizers/Utils.swift index 78687ce..9efacc2 100644 --- a/Sources/Tokenizers/Utils.swift +++ b/Sources/Tokenizers/Utils.swift @@ -77,3 +77,7 @@ struct Utils { } } +enum Constants { + static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"# +} + diff --git a/Tests/PreTokenizerTests/PreTokenizerTests.swift b/Tests/PreTokenizerTests/PreTokenizerTests.swift index 3e4e064..93d6838 100644 --- a/Tests/PreTokenizerTests/PreTokenizerTests.swift +++ b/Tests/PreTokenizerTests/PreTokenizerTests.swift @@ -171,4 +171,24 @@ class PreTokenizerTests: XCTestCase { ["▁Hey", "▁my", "▁friend", "▁", "▁", "▁how", "▁are", "▁you"] ) } + + func testBertPreTokenizer() { + let preTokenizer1 = BertPreTokenizer(config: Config([:])) + XCTAssertEqual( + preTokenizer1.preTokenize(text: "Hey friend!"), + ["Hey", "friend", "!"] + ) + XCTAssertEqual( + preTokenizer1.preTokenize(text: "Hey friend! How are you?!?"), + ["Hey", "friend", "!", "How", "are", "you", "?", "!", "?"] + ) + XCTAssertEqual( + preTokenizer1.preTokenize(text: " Hey, friend , what's up? "), + ["Hey", ",", "friend", ",", "what", "\'", "s", "up", "?"] + ) + XCTAssertEqual( + preTokenizer1.preTokenize(text: " Hey, friend , 0 99 what's up? "), + ["Hey", ",", "friend", ",", "0", "99", "what", "\'", "s", "up", "?"] + ) + } }