Supporting more BERT-like models #89

Closed
wants to merge 10 commits into from
14 changes: 13 additions & 1 deletion Sources/Hub/Hub.swift
@@ -100,7 +100,19 @@ public struct Config {
}

/// Tuple of token identifier and string value
public var tokenValue: (UInt, String)? { value as? (UInt, String) }
public var tokenValue: (UInt, String)? {
guard let array = value as? [Any], array.count == 2 else {
return nil
}

if let first = array[0] as? String, let second = array[1] as? Int64 {
return (UInt(second), first)
} else if let first = array[0] as? Int64, let second = array[1] as? String {
return (UInt(first), second)
}

return nil
}
pcuenca marked this conversation as resolved.
}

public class LanguageModelConfigurationFromHub {
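Note (not part of the diff): a minimal sketch of how the new accessor behaves, assuming the public dictionary initializer and dynamic member lookup that `Config` already provides. The `Int64` casts in the implementation matter because JSONSerialization surfaces numbers as NSNumber, which bridges to `Int64`; a plain Swift `Int` literal stored in `[Any]` would not.

// Sketch only — hypothetical values, not part of the PR.
let tokenFirst = Config(["sep": ["</s>", Int64(2)] as [Any]])
let idFirst = Config(["sep": [Int64(2), "</s>"] as [Any]])
// Both element orderings found in tokenizer.json files resolve to the same (id, token) tuple:
if let sep = tokenFirst.sep?.tokenValue { print(sep) } // (2, "</s>")
if let sep = idFirst.sep?.tokenValue { print(sep) }    // (2, "</s>")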
47 changes: 42 additions & 5 deletions Sources/Tokenizers/Decoder.swift
@@ -23,7 +23,7 @@ extension Decoder {

enum DecoderType: String {
case Sequence
// case WordPiece
case WordPiece
Member: ❤️

case ByteLevel
case Replace
case ByteFallback
@@ -37,7 +37,11 @@ struct DecoderFactory {
static func fromConfig(config: Config?, addedTokens: Set<String>? = nil) -> Decoder? {
// TODO: not sure if we need to include `addedTokens` in all the decoder initializers (and the protocol)
guard let config = config else { return nil }
guard let typeName = config.type?.stringValue else { return nil }
guard var typeName = config.type?.stringValue else { return nil }
if typeName.hasSuffix("Decoder") {
typeName = String(typeName.dropLast("Decoder".count))
}

let type = DecoderType(rawValue: typeName)
switch type {
case .Sequence : return DecoderSequence(config: config)
@@ -47,6 +51,7 @@
case .Fuse : return FuseDecoder(config: config)
case .Strip : return StripDecoder(config: config)
case .Metaspace : return MetaspaceDecoder(config: config)
case .WordPiece : return WordPieceDecoder(config: config)
default : fatalError("Unsupported Decoder type: \(typeName)")
}
}
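Note (not part of the diff): the suffix stripping above means that both the short names emitted by some configs ("WordPiece") and the longer transformers-style names ("WordPieceDecoder") resolve to the same enum case; the same pattern is repeated below for the Normalizer, PostProcessor, and PreTokenizer factories. A standalone sketch with a hypothetical helper:

// Hypothetical helper illustrating the normalization applied in each factory.
func normalized(_ typeName: String, droppingSuffix suffix: String) -> String {
    // Drop a trailing suffix such as "Decoder", "Normalizer", "Processing", or "PreTokenizer".
    typeName.hasSuffix(suffix) ? String(typeName.dropLast(suffix.count)) : typeName
}
print(normalized("WordPieceDecoder", droppingSuffix: "Decoder")) // "WordPiece"
print(normalized("WordPiece", droppingSuffix: "Decoder"))        // "WordPiece" (unchanged)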
@@ -128,7 +133,7 @@ class ByteFallbackDecoder: Decoder {
func decode(tokens: [String]) -> [String] {
var newTokens: [String] = []
var byteTokens: [Int] = []

func parseByte(_ token: String) -> Int? {
guard token.count == 6 && token.hasPrefix("<0x") && token.hasSuffix(">") else {
return nil
@@ -192,7 +197,7 @@ class MetaspaceDecoder: Decoder {
addPrefixSpace = config.addPrefixSpace?.boolValue ?? false
replacement = config.replacement?.stringValue ?? "▁"
}

func decode(tokens: [String]) -> [String] {
var replaced = tokens.map { token in
token.replacingOccurrences(of: replacement, with: " ")
@@ -204,6 +209,38 @@
}
}

class WordPieceDecoder: Decoder {
let prefix: String
let cleanup: Bool

required public init(config: Config) {
guard let prefix = config.prefix?.stringValue else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
self.prefix = prefix
self.cleanup = config.cleanup?.boolValue ?? false
}

func decode(tokens: [String]) -> [String] {
return tokens.enumerated().map { index, token in
var decodedToken = token
if index != 0 {
if decodedToken.hasPrefix(self.prefix) {
decodedToken = String(decodedToken.dropFirst(self.prefix.count))
} else {
decodedToken = " " + decodedToken
}
}
if self.cleanup {
decodedToken = cleanUpTokenization(decodedToken)
}
return decodedToken
}
}

private func cleanUpTokenization(_ token: String) -> String {
return token.trimmingCharacters(in: .whitespacesAndNewlines)
}
}
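Note (not part of the diff): a usage sketch of the decoder above, with hypothetical config values. The first token passes through unchanged, continuation pieces carrying the configured prefix (conventionally "##" in BERT vocabularies) are glued to the previous token, and any other token gets a leading space:

// Sketch only — assumes this file's WordPieceDecoder and Hub's Config.
let decoder = WordPieceDecoder(config: Config(["prefix": "##", "cleanup": false]))
let decoded = decoder.decode(tokens: ["hug", "##ging", "face"])
print(decoded)          // ["hug", "ging", " face"]
print(decoded.joined()) // "hugging face"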

// We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once)
public extension String {
func trimmingFromStart(character: Character = " ", upto: Int) -> String {
@@ -215,7 +252,7 @@ public extension String {
}
return result
}

func trimmingFromEnd(character: Character = " ", upto: Int) -> String {
var result = self
var trimmed = 0
6 changes: 5 additions & 1 deletion Sources/Tokenizers/Normalizer.swift
@@ -39,7 +39,11 @@ enum NormalizerType: String {
struct NormalizerFactory {
static func fromConfig(config: Config?) -> Normalizer? {
guard let config = config else { return nil }
guard let typeName = config.type?.stringValue else { return nil }
guard var typeName = config.type?.stringValue else { return nil }
if typeName.hasSuffix("Normalizer") {
typeName = String(typeName.dropLast("Normalizer".count))
}

let type = NormalizerType(rawValue: typeName)
switch type {
case .Sequence: return NormalizerSequence(config: config)
40 changes: 26 additions & 14 deletions Sources/Tokenizers/PostProcessor.swift
@@ -22,20 +22,30 @@ extension PostProcessor {
}

enum PostProcessorType: String {
case TemplateProcessing
case Template
case ByteLevel
case RobertaProcessing
case Bert
case Roberta

static let BertProcessing = "Bert"
static let RobertaProcessing = "Roberta"
static let TemplateProcessing = "Template"
Comment on lines -25 to +32
Member: I'm not sure this is worth doing; I find the aliases and the suffix removal distracting for little benefit. My original approach was to simply use the same names that appear in the json, so someone reading both could easily match.

Contributor Author: It's very hard to use JSON as the ground truth, as they come in all shapes and sizes. A shorter name was needed for our models to work, but I've added the static variables for backward compatibility, to avoid breaking the library for other users.

Member: I'd say the json names used in transformers should be quite stable right now. Why do your models require shorter names? (just curious)

Contributor Author: I guess no model "requires" specific names, and JSONs can always be changed, but it generally results in a snowballing set of changes that have to be applied on every platform...

In UForm I have identical tests that run for the same models on the same data, across all 3 languages, across ONNX, PyTorch, and CoreML. You can check them here.

If a certain behavior is standard in the more popular ports of the library (Python and JS), I assume Hugging Face may want to provide the same behavior here to encourage adoption. A lot of people would probably appreciate the portability 🤗

}

struct PostProcessorFactory {
static func fromConfig(config: Config?) -> PostProcessor? {
guard let config = config else { return nil }
guard let typeName = config.type?.stringValue else { return nil }
guard var typeName = config.type?.stringValue else { return nil }
if typeName.hasSuffix("Processing") {
typeName = String(typeName.dropLast("Processing".count))
}

let type = PostProcessorType(rawValue: typeName)
switch type {
case .TemplateProcessing: return TemplateProcessing(config: config)
case .Template : return TemplateProcessing(config: config)
case .ByteLevel : return ByteLevelPostProcessor(config: config)
case .RobertaProcessing : return RobertaProcessing(config: config)
case .Bert : return BertProcessing(config: config)
case .Roberta : return RobertaProcessing(config: config)
default : fatalError("Unsupported PostProcessor type: \(typeName)")
}
}
@@ -55,7 +65,7 @@ class TemplateProcessing: PostProcessor {

func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] {
let config = tokensPair == nil ? single : pair

var toReturn: [String] = []
for item in config {
if let specialToken = item.SpecialToken {
@@ -77,14 +87,14 @@ class ByteLevelPostProcessor: PostProcessor {
func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { tokens }
}

class RobertaProcessing: PostProcessor {
class BertProcessing: PostProcessor {
private let sep: (UInt, String)
private let cls: (UInt, String)
/// Trim all remaining space, or leave one space character if `addPrefixSpace` is `true`.
private let trimOffset: Bool
/// Keep one space character on each side. Depends on `trimOffsets` being `true`.
private let addPrefixSpace: Bool

required public init(config: Config) {
guard let sep = config.sep?.tokenValue else { fatalError("Missing `sep` processor configuration") }
guard let cls = config.cls?.tokenValue else { fatalError("Missing `cls` processor configuration") }
@@ -101,22 +111,22 @@ class RobertaProcessing: PostProcessor {
if addPrefixSpace {
outTokens = outTokens.map({ trimExtraSpaces(token: $0) })
tokensPair = tokensPair?.map({ trimExtraSpaces(token: $0) })
} else {
} else {
outTokens = outTokens.map({ $0.trimmingCharacters(in: .whitespaces) })
tokensPair = tokensPair?.map({ $0.trimmingCharacters(in: .whitespaces) })
}
}

outTokens = [self.cls.1] + outTokens + [self.sep.1]
if let tokensPair = tokensPair, !tokensPair.isEmpty {
// Yes, it adds another `sep`.
// https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/roberta/hub_interface.py#L58-L65
outTokens += [self.sep.1] + tokensPair + [self.sep.1]
}

return outTokens
}

/// Some tokens need one space around them
/// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L203-L235
private func trimExtraSpaces(token: String) -> String {
@@ -126,14 +136,16 @@ class RobertaProcessing: PostProcessor {
let suffixIndex = token.index(token.startIndex, offsetBy: token.count - suffixOffset)
return String(token[prefixIndex..<suffixIndex])
}

private func findPrefixIndex(text: String) -> Int {
guard !text.isEmpty, text.first!.isWhitespace else { return 0 }
return text.prefix(while: { $0.isWhitespace }).count - 1
}

Comment on lines -134 to +144
Member: If we could keep the empty lines empty, that'd be awesome. Otherwise no big deal; we can address all those issues in the style PR.

private func findSuffixIndex(text: String) -> Int {
guard !text.isEmpty, text.last!.isWhitespace else { return 0 }
return text.reversed().prefix(while: { $0.isWhitespace }).count - 1
}
}

class RobertaProcessing: BertProcessing { }
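Note (not part of the diff): a post-processing sketch with hypothetical special-token values; the trimOffsets/addPrefixSpace defaults come from the collapsed part of the hunk. Single sequences get wrapped as [CLS] … [SEP], while pairs pick up the fairseq-style double separator linked in the comment above:

// Sketch only — assumes this file's BertProcessing and Hub's Config.
let config = Config([
    "cls": ["[CLS]", Int64(101)] as [Any],
    "sep": ["[SEP]", Int64(102)] as [Any],
])
let processor = BertProcessing(config: config)
print(processor.postProcess(tokens: ["hello", "world"]))
// ["[CLS]", "hello", "world", "[SEP]"]
print(processor.postProcess(tokens: ["hello"], tokensPair: ["world"]))
// ["[CLS]", "hello", "[SEP]", "[SEP]", "world", "[SEP]"]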
22 changes: 21 additions & 1 deletion Sources/Tokenizers/PreTokenizer.swift
@@ -36,6 +36,7 @@ enum PreTokenizerType: String {
case Sequence
case ByteLevel
case Punctuation
case Bert
case Digits
case Split
case Whitespace
@@ -48,12 +49,17 @@
struct PreTokenizerFactory {
static func fromConfig(config: Config?) -> PreTokenizer? {
guard let config = config else { return nil }
guard let typeName = config.type?.stringValue else { return nil }
guard var typeName = config.type?.stringValue else { return nil }
if typeName.hasSuffix("PreTokenizer") {
typeName = String(typeName.dropLast("PreTokenizer".count))
}

let type = PreTokenizerType(rawValue: typeName)
switch type {
case .Sequence : return PreTokenizerSequence(config: config)
case .ByteLevel: return ByteLevelPreTokenizer(config: config)
case .Punctuation: return PunctuationPreTokenizer(config: config)
case .Bert: return BertPreTokenizer(config: config)
case .Digits: return DigitsPreTokenizer(config: config)
case .Split: return SplitPreTokenizer(config: config)
case .Whitespace, .WhitespaceSplit: return WhitespacePreTokenizer(config: config)
@@ -192,6 +198,20 @@ class PunctuationPreTokenizer: PreTokenizer {
}
}

class BertPreTokenizer: PreTokenizer {
// Identical to PunctuationPreTokenizer, but with a different regex
let PUNCTUATION_REGEX = #"\p{P}\u0021-\u002F\u003A-\u0040\u005B-\u0060\u007B-\u007E"#
let re: String

required init(config: Config) {
re = "[^\\s\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]"
}

func preTokenize(text: String) -> [String] {
return text.ranges(of: re).map { String(text[$0]) }
}
}
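Note (not part of the diff): a quick sketch of what the pre-tokenizer above produces, assuming the String.ranges(of:) regex helper already in this package. Runs of non-space, non-punctuation characters stay together, each punctuation character becomes its own token, and whitespace is discarded:

// Sketch only — hypothetical input, config is unused by this pre-tokenizer.
let pre = BertPreTokenizer(config: Config([:]))
print(pre.preTokenize(text: "Hey friend!     How are you?!?"))
// ["Hey", "friend", "!", "How", "are", "you", "?", "!", "?"]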

class DigitsPreTokenizer: PreTokenizer {
let re: String

49 changes: 31 additions & 18 deletions Sources/Tokenizers/Tokenizer.swift
@@ -59,35 +59,48 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {

struct TokenizerModel {
static let knownTokenizers: [String : PreTrainedTokenizerModel.Type] = [
"BertTokenizer" : BertTokenizer.self,
"CodeGenTokenizer" : CodeGenTokenizer.self,
"CodeLlamaTokenizer" : CodeLlamaTokenizer.self,
"FalconTokenizer" : FalconTokenizer.self,
"GemmaTokenizer" : GemmaTokenizer.self,
"GPT2Tokenizer" : GPT2Tokenizer.self,
"LlamaTokenizer" : LlamaTokenizer.self,
"T5Tokenizer" : T5Tokenizer.self,
"WhisperTokenizer" : WhisperTokenizer.self,
"CohereTokenizer" : CohereTokenizer.self,
"PreTrainedTokenizer": BPETokenizer.self
"Bert" : BertTokenizer.self,
"CodeGen" : CodeGenTokenizer.self,
"CodeLlama" : CodeLlamaTokenizer.self,
"Falcon" : FalconTokenizer.self,
"Gemma" : GemmaTokenizer.self,
"GPT2" : GPT2Tokenizer.self,
"Llama" : LlamaTokenizer.self,
"Unigram" : UnigramTokenizer.self,
"T5" : T5Tokenizer.self,
"Whisper" : WhisperTokenizer.self,
"Cohere" : CohereTokenizer.self,
"PreTrained": BPETokenizer.self
Comment on lines -62 to +73
Member: Hmmm, I think I'd rather keep the same names if possible. If I search in the project for "PreTrainedTokenizer", I'd like to see this entry.

]

static func unknownToken(from tokenizerConfig: Config) -> String? {
return tokenizerConfig.unkToken?.content?.stringValue ?? tokenizerConfig.unkToken?.stringValue
}

public static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws -> TokenizingModel {
guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else {
throw TokenizerError.missingTokenizerClassInConfig
}

// Some tokenizer_class entries use a Fast suffix
let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
guard let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] else {
throw TokenizerError.unsupportedTokenizer(tokenizerName)
var tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
if tokenizerName.hasSuffix("Tokenizer") {
tokenizerName = String(tokenizerName.dropLast("Tokenizer".count))
}

return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
// Try to perform a direct case-sensitive lookup first
if let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] {
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
} else {
// If the direct lookup fails, perform a case-insensitive scan over the keys
Member: Nice! This is where we may want to drop the Tokenizer suffix, in my opinion.

if let key = TokenizerModel.knownTokenizers.keys.first(where: { $0.lowercased() == tokenizerName.lowercased() }) {
if let tokenizerClass = TokenizerModel.knownTokenizers[key] {
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}
}
}

throw TokenizerError.unsupportedTokenizer(tokenizerName)
}
}
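Note (not part of the diff): a standalone sketch of the lookup chain above — Fast-suffix removal, Tokenizer-suffix removal, direct lookup, then the case-insensitive scan — using a hypothetical string table in place of the real class registry:

// Hypothetical stand-in for TokenizerModel.knownTokenizers.
let known = ["Bert": "BertTokenizer", "GPT2": "GPT2Tokenizer"]
var name = "BERTTokenizerFast".replacingOccurrences(of: "Fast", with: "")
if name.hasSuffix("Tokenizer") { name = String(name.dropLast("Tokenizer".count)) } // "BERT"
// Direct case-sensitive lookup first, then the case-insensitive fallback:
let match = known[name] ?? known.first(where: { $0.key.lowercased() == name.lowercased() })?.value
print(match ?? "unsupported") // "BertTokenizer", found via the case-insensitive scan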
