Added support for Bert models (#137)
Co-authored-by: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
jkrukowski and ashvardanian authored Oct 30, 2024
1 parent 4d25d20 commit 2c68d53
Showing 5 changed files with 79 additions and 4 deletions.
Sources/Tokenizers/Decoder.swift (37 additions, 1 deletion)
@@ -23,7 +23,7 @@ extension Decoder {
 
 enum DecoderType: String {
     case Sequence
-    // case WordPiece
+    case WordPiece
     case ByteLevel
     case Replace
     case ByteFallback
@@ -47,11 +47,47 @@ struct DecoderFactory {
         case .Fuse : return FuseDecoder(config: config)
         case .Strip : return StripDecoder(config: config)
         case .Metaspace : return MetaspaceDecoder(config: config)
+        case .WordPiece : return WordPieceDecoder(config: config)
         default : fatalError("Unsupported Decoder type: \(typeName)")
         }
     }
 }

+class WordPieceDecoder: Decoder {
+    let prefix: String
+    let cleanup: Bool
+
+    required public init(config: Config) {
+        guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
+        self.prefix = prefix
+        self.cleanup = config.cleanup?.boolValue ?? false
+    }
+
+    func decode(tokens: [String]) -> [String] {
+        var newTokens = [String]()
+        newTokens.reserveCapacity(tokens.count)
+        for (index, token) in tokens.enumerated() {
+            var decodedToken = token
+            if index != 0 {
+                if decodedToken.hasPrefix(prefix) {
+                    decodedToken = String(decodedToken.dropFirst(prefix.count))
+                } else {
+                    decodedToken = " \(decodedToken)"
+                }
+            }
+            if cleanup {
+                decodedToken = cleanUpTokenization(decodedToken)
+            }
+            newTokens.append(decodedToken)
+        }
+        return newTokens
+    }
+
+    private func cleanUpTokenization(_ token: String) -> String {
+        return token.trimmingCharacters(in: .whitespacesAndNewlines)
+    }
+}
+
 class DecoderSequence: Decoder {
     let decoders: [Decoder]

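For context, a minimal usage sketch of the new decoder (illustrative only: the dictionary-literal Config initializer and the prefix/cleanup keys are assumptions based on the test code in this commit, not something this diff defines):

import Tokenizers

// Hypothetical setup; a real tokenizer would build this Config from tokenizer.json.
let decoder = WordPieceDecoder(config: Config(["prefix": "##", "cleanup": false]))

// Every token after the first either loses the "##" continuation prefix
// or gains a leading space, so joining the pieces restores the words.
let pieces = decoder.decode(tokens: ["play", "##ing", "with", "BERT"])
print(pieces.joined())  // "playing with BERT"

Note that with cleanup enabled, cleanUpTokenization trims whitespace from each token individually, which would also strip the separating space added above; the sketch therefore leaves cleanup off.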
Sources/Tokenizers/Normalizer.swift (2 additions, 1 deletion)
@@ -31,6 +31,7 @@ enum NormalizerType: String {
     case NFKD
     case NFKC
     case Bert
+    case BertNormalizer
     case Precompiled
     case StripAccents
     case Strip
@@ -51,7 +52,7 @@ struct NormalizerFactory {
         case .NFC: return NFCNormalizer(config: config)
         case .NFKD: return NFKDNormalizer(config: config)
         case .NFKC: return NFKCNormalizer(config: config)
-        case .Bert: return BertNormalizer(config: config)
+        case .Bert, .BertNormalizer: return BertNormalizer(config: config)
         case .Precompiled: return PrecompiledNormalizer(config: config)
         case .StripAccents: return StripAccentsNormalizer(config: config)
         case .Strip: return StripNormalizer(config: config)
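The alias matters because tokenizer.json files serialized by huggingface/tokenizers name this normalizer "BertNormalizer", while the short form "Bert" is what this library matched so far. A quick sketch of both raw values resolving (both now reach BertNormalizer through the factory):

NormalizerType(rawValue: "Bert")            // .Bert
NormalizerType(rawValue: "BertNormalizer")  // .BertNormalizer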
Sources/Tokenizers/PreTokenizer.swift (16 additions, 2 deletions)
@@ -46,6 +46,7 @@ enum PreTokenizerType: String {
     case Whitespace
     case WhitespaceSplit
     case Metaspace
+    case BertPreTokenizer
     // Several more to be supported
     case Unknown = ""
 }
@@ -63,11 +64,25 @@ struct PreTokenizerFactory {
         case .Split: return SplitPreTokenizer(config: config)
         case .Whitespace, .WhitespaceSplit: return WhitespacePreTokenizer(config: config)
         case .Metaspace: return MetaspacePreTokenizer(config: config)
+        case .BertPreTokenizer: return BertPreTokenizer(config: config)
         default: fatalError("Unsupported PreTokenizer type: \(typeName)")
         }
     }
 }

+class BertPreTokenizer: PreTokenizer {
+    let re: String
+
+    required init(config: Config) {
+        // Ref: https://github.com/huggingface/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1002
+        re = "[^\\s\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]"
+    }
+
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return text.ranges(of: re).map { String(text[$0]) }
+    }
+}
+
 class PreTokenizerSequence: PreTokenizer {
     let preTokenizers: [PreTokenizer]

@@ -184,11 +199,10 @@ class ByteLevelPreTokenizer: PreTokenizer {
 }
 
 class PunctuationPreTokenizer: PreTokenizer {
-    let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
     let re: String
 
     required init(config: Config) {
-        re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
+        re = "[^\(Constants.PUNCTUATION_REGEX)]+|[\(Constants.PUNCTUATION_REGEX)]+"
     }
 
     func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
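A quick sketch of the resulting behavior, mirroring the tests below: the regex keeps runs of characters that are neither whitespace nor punctuation, and, because its second alternative carries no + quantifier (unlike PunctuationPreTokenizer above), each punctuation character becomes a token of its own:

let preTokenizer = BertPreTokenizer(config: Config([:]))
print(preTokenizer.preTokenize(text: "what's up?!"))
// ["what", "'", "s", "up", "?", "!"]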
Sources/Tokenizers/Utils.swift (4 additions, 0 deletions)
@@ -77,3 +77,7 @@ struct Utils {
     }
 }
 
+enum Constants {
+    static let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
+}
+
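The shared fragment is meant to be interpolated inside a character class: it combines Unicode punctuation (\p{P}) with the ASCII ranges U+0021-U+002F, U+003A-U+0040, U+005B-U+0060 and U+007B-U+007E, which also catch symbols such as "+", "<" or "^" that fall under \p{S} rather than \p{P}. A hypothetical check, not part of this commit:

import Foundation

// Embed the fragment in a character class and list what it matches.
let pattern = "[\(Constants.PUNCTUATION_REGEX)]"
let regex = try! NSRegularExpression(pattern: pattern)
let text = "e = mc^2, ok?"
let nsRange = NSRange(text.startIndex..., in: text)
let hits = regex.matches(in: text, range: nsRange).map { String(text[Range($0.range, in: text)!]) }
print(hits)  // ["=", "^", ",", "?"]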

Tests/PreTokenizerTests/PreTokenizerTests.swift (20 additions, 0 deletions)
@@ -171,4 +171,24 @@ class PreTokenizerTests: XCTestCase {
             ["▁Hey", "▁my", "▁friend", "", "▁<s>", "▁how", "▁are", "▁you"]
         )
     }

+    func testBertPreTokenizer() {
+        let preTokenizer1 = BertPreTokenizer(config: Config([:]))
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: "Hey friend!"),
+            ["Hey", "friend", "!"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: "Hey friend! How are you?!?"),
+            ["Hey", "friend", "!", "How", "are", "you", "?", "!", "?"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: " Hey, friend , what's up? "),
+            ["Hey", ",", "friend", ",", "what", "\'", "s", "up", "?"]
+        )
+        XCTAssertEqual(
+            preTokenizer1.preTokenize(text: " Hey, friend , 0 99 what's up? "),
+            ["Hey", ",", "friend", ",", "0", "99", "what", "\'", "s", "up", "?"]
+        )
+    }
 }
