diff --git a/.github/workflows/development-tests.yml b/.github/workflows/development-tests.yml index ae84251..78d6e71 100644 --- a/.github/workflows/development-tests.yml +++ b/.github/workflows/development-tests.yml @@ -2,7 +2,6 @@ name: Development Tests on: pull_request: - branches: ["main"] pull_request_review: types: [submitted] workflow_dispatch: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 557ab22..6853b4b 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -75,8 +75,19 @@ jobs: sleep 15 xcrun simctl list devices - name: Build and Test - ${{ matrix.run-config['name'] }} + id: test-step if: ${{ matrix.run-config['condition'] == true }} + continue-on-error: true run: | set -o pipefail xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' + + - name: Upload Test Results + if: failure() && steps.test-step.outcome == 'failure' + uses: actions/upload-artifact@v4 + with: + name: test-results-${{ matrix.run-config['name'] }} + path: | + ~/Library/Developer/Xcode/DerivedData/**/Logs/Test/*.xcresult + retention-days: 5 diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index fb44b79..689d07b 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -93,7 +93,7 @@ public extension AudioProcessing { } static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? { - guard startIndex >= 0 && startIndex < audioArray.count else { + guard startIndex >= 0, startIndex < audioArray.count else { Logging.error("startIndex is outside the buffer size") return nil } @@ -228,7 +228,11 @@ public class AudioProcessor: NSObject, AudioProcessing { guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else { throw WhisperError.loadAudioFailed("Unable to create audio buffer") } - try audioFile.read(into: buffer, frameCount: frameCount) + do { + try audioFile.read(into: buffer, frameCount: frameCount) + } catch { + throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)") + } outputBuffer = buffer } else { // Audio needs resampling to 16khz diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift index bb7ef62..ee17dc2 100644 --- a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift @@ -117,7 +117,7 @@ open class VoiceActivityDetector { } } - // MARK - Utility + // MARK: - Utility func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] { let nonSilentChunks = calculateActiveChunks(in: waveform) diff --git a/Sources/WhisperKit/Core/AudioEncoder.swift b/Sources/WhisperKit/Core/AudioEncoder.swift index 06337cd..9205a5c 100644 --- a/Sources/WhisperKit/Core/AudioEncoder.swift +++ b/Sources/WhisperKit/Core/AudioEncoder.swift @@ -3,17 +3,22 @@ import CoreML +public protocol AudioEncoderOutputType {} +extension MLMultiArray: AudioEncoderOutputType {} + /// AudioEncoding protocol defines the requirements for an audio encoding implementation. 
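// The marker protocols above decouple encoder output from MLMultiArray so that other
// tensor backends can flow through the same pipeline. A minimal sketch of an alternative
// conformance (hypothetical type, for illustration only, not part of this change):
struct PlainTensorOutput: AudioEncoderOutputType {
    var scalars: [Float]
    var shape: [Int]
}
// Implementations accept `any AudioEncoderOutputType` and downcast to the representation
// they actually support, as the backwards-compatible `AudioEncoder` below does.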
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public protocol AudioEncoding { /// The size of the embedding produced by the encoder. var embedSize: Int? { get } /// Encodes the given audio features asynchronously. /// - Parameter features: The audio features to be encoded. - /// - Returns: An optional `MLMultiArray` containing the encoded features. - func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? + /// - Returns: An optional tensor containing the encoded features. + func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? } +/// Backwards-compatible AudioEncoder implementation @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public class AudioEncoder: AudioEncoding, WhisperMLModel { public var model: MLModel? @@ -36,8 +41,15 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel { public init() {} + public func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? { + guard let features = features as? MLMultiArray else { + throw WhisperError.audioProcessingFailed("AudioEncoder input must be MLMultiArray") + } + + return try await encodeFeatures(features) + } + public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { - // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000) guard let model else { throw WhisperError.modelsUnavailable() } diff --git a/Sources/WhisperKit/Core/FeatureExtractor.swift b/Sources/WhisperKit/Core/FeatureExtractor.swift index 6569809..0fb0f68 100644 --- a/Sources/WhisperKit/Core/FeatureExtractor.swift +++ b/Sources/WhisperKit/Core/FeatureExtractor.swift @@ -7,10 +7,15 @@ import CoreGraphics import CoreML import Foundation +public protocol FeatureExtractorOutputType {} +extension MLMultiArray: FeatureExtractorOutputType {} + public protocol FeatureExtracting { + associatedtype OutputType: FeatureExtractorOutputType + var melCount: Int? { get } var windowSamples: Int? { get } - func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? + func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> OutputType? 
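// The new associatedtype lets each concrete extractor pin its output representation at
// compile time while call sites stay generic over `FeatureExtractorOutputType`. A hedged
// sketch of a conformance (hypothetical class, for illustration only, not part of this diff):
//
//     final class PassthroughFeatureExtractor: FeatureExtracting {
//         typealias OutputType = MLMultiArray
//         let melCount: Int? = 80
//         let windowSamples: Int? = 480_000
//         func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
//             inputAudio // identity passthrough; a real extractor computes the log-mel here
//         }
//     }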
} @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 8f204e9..bf7d7eb 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -299,7 +299,7 @@ public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable { } } -public struct DecodingInputs { +open class DecodingInputs { public var initialPrompt: [Int] public var inputIds: MLMultiArray public var cacheLength: MLMultiArray @@ -580,6 +580,7 @@ public struct TranscriptionResult: Codable { Total Tokens: \(totalTokens) Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s Real Time Factor: \(String(format: "%.3f", rtf)) + Speed Factor: \(String(format: "%.3f", 1.0 / rtf)) Fallbacks: \(timings.totalDecodingFallbacks) """) } diff --git a/Sources/WhisperKit/Core/Text/LogitsFilter.swift b/Sources/WhisperKit/Core/Text/LogitsFilter.swift index 28218de..174bb50 100644 --- a/Sources/WhisperKit/Core/Text/LogitsFilter.swift +++ b/Sources/WhisperKit/Core/Text/LogitsFilter.swift @@ -75,8 +75,9 @@ open class TimestampRulesFilter: LogitsFiltering { public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray { guard let sampleBegin = sampleBegin(for: tokens), - sampleBegin > tokens.count + sampleBegin <= tokens.count else { + // Early return if we are still prefilling the prompt return logits } diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index ce15cd5..3657268 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -29,129 +29,179 @@ open class GreedyTokenSampler: TokenSampling { } public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { - var softmaxOutput: BNNSNDArrayDescriptor? - var argmaxOutput: BNNSNDArrayDescriptor? - var softmaxInput: BNNSNDArrayDescriptor? - var softmaxInputNeedsDeallocate = false + var nextTokens = tokens + var nextLogprobs = logProbs + var completed = false + if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) { + // Use MLTensor operations if available for sampling + // Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift + var logitsTensor = MLTensor(MLShapedArray(logits)).cast(to: Float.self) + var nextTokenTensor: MLTensor + var nextLogprobTensor: MLTensor - var nextToken: Int? - - do { - let logitsRawPointer = UnsafeMutableRawBufferPointer( - start: logits.dataPointer, - count: logits.count * MemoryLayout.stride - ) - - let logitsDescriptor = BNNSNDArrayDescriptor( - data: logitsRawPointer, - scalarType: FloatType.self, - shape: .vector(logits.count, stride: 1) - )! + if temperature != 0.0 { + // Scale logits by temperature if > 0 + logitsTensor = logitsTensor / temperature + } - softmaxInput = logitsDescriptor + // Always softmax once + let softmaxScores = logitsTensor.softmax(alongAxis: -1) - // Scale logits by temperature if > 0 if temperature != 0.0 { - let scaledLogits = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: FloatType.self, - shape: .vector(logits.count, stride: 1) - ) + // top-k multinomial sampling + let (topKProbs, topKIndices) = softmaxScores.topK(decodingOptions.topK) - try! 
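// The MLTensor sampling path above vectorizes multinomial top-k selection: after
// `cumulativeSum`, every position whose running total is still below `rnd` has 100.0
// added, pushing it past all valid candidates, so `argsort()[..., 0]` yields the first
// index whose cumulative probability reaches `rnd`. The same idea in scalar Swift
// (illustrative only, not part of this diff):
//
//     let probs: [Float] = [0.5, 0.3, 0.2] // sorted top-k probabilities
//     let rnd = Float.random(in: 0..<probs.reduce(0, +))
//     var running: Float = 0
//     let cumsum = probs.map { p -> Float in running += p; return running }
//     let chosenIndex = cumsum.firstIndex { $0 >= rnd } ?? probs.count - 1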
BNNS.applyActivation( - activation: BNNS.ActivationFunction.linear(alpha: Float(1 / temperature)), - input: logitsDescriptor, - output: scaledLogits, - batchSize: 1 - ) + let rnd = topKProbs.sum() * Float.random(in: 0..<1) + var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1) + accumTopKProbs += (accumTopKProbs .< rnd) * 100.0 + let topKIndex = accumTopKProbs.argsort()[..., 0] - softmaxInput = scaledLogits - softmaxInputNeedsDeallocate = true + nextTokenTensor = topKIndices.gathering( + atIndices: topKIndex, + alongAxis: topKIndices.rank - 1 + ) + nextLogprobTensor = topKProbs.gathering( + atIndices: topKIndex, + alongAxis: topKIndices.rank - 1 + ).log() + } else { + nextTokenTensor = logitsTensor.argmax(alongAxis: -1) + nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log() } - // Always softmax once - softmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized( - scalarType: Float.self, - shape: .vector(logits.count, stride: 1) - ) - - try BNNS.applyActivation( - activation: BNNS.ActivationFunction.softmax, - input: softmaxInput!, - output: softmaxOutput!, - batchSize: 1 - ) + let nextToken = nextTokenTensor.asIntArray()[0] + let nextLogprob = nextLogprobTensor.asFloatArray()[0] - if temperature != 0.0 { - // top-k multinomial sampling - let k = decodingOptions.topK + nextTokens = tokens + [nextToken] + nextLogprobs = logProbs + [nextLogprob] + completed = nextToken == eotToken - let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float.self, shape: .vector(k, stride: 1)) - let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int32.self, shape: .vector(k, stride: 1)) + } else { + // TODO: BNNS operations here are deprecated, replace with vDSP or MLX + var softmaxOutput: BNNSNDArrayDescriptor? + var argmaxOutput: BNNSNDArrayDescriptor? + var softmaxInput: BNNSNDArrayDescriptor? + var softmaxInputNeedsDeallocate = false - try! BNNS.applyTopK( - k: k, - input: softmaxOutput!, - bestValues: bestValues, - bestIndices: bestIndices, - axis: 0, - batchSize: 1 + var nextToken: Int? + + do { + let logitsRawPointer = UnsafeMutableRawBufferPointer( + start: logits.dataPointer, + count: logits.count * MemoryLayout<FloatType>.stride ) - let bestValuesResult = bestValues.makeArray(of: Float.self)! - let bestIndicesResult = bestIndices.makeArray(of: Int32.self)! - - bestValues.deallocate() - bestIndices.deallocate() - - // multinomial sample from top-k - let sumOfbestIndicesResult = bestValuesResult.reduce(0, +) - let rnd = Float.random(in: 0..<sumOfbestIndicesResult) - var chosenIndex = 0 - var accumulator: Float = 0.0 - for (index, value) in bestValuesResult.enumerated() { - accumulator += value - if accumulator >= rnd { - chosenIndex = index - break - } + let logitsDescriptor = BNNSNDArrayDescriptor( + data: logitsRawPointer, + scalarType: FloatType.self, + shape: .vector(logits.count, stride: 1) + )! + + softmaxInput = logitsDescriptor + + // Scale logits by temperature if > 0 + if temperature != 0.0 { + let scaledLogits = BNNSNDArrayDescriptor.allocateUninitialized( + scalarType: FloatType.self, + shape: .vector(logits.count, stride: 1) + ) + + try! BNNS.applyActivation( + activation: BNNS.ActivationFunction.linear(alpha: Float(1 / temperature)), + input: logitsDescriptor, + output: scaledLogits, + batchSize: 1 + ) + + softmaxInput = scaledLogits + softmaxInputNeedsDeallocate = true } - nextToken = Int(bestIndicesResult[chosenIndex]) - } else { - // Argmax sampling - argmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized( + // Always softmax once + softmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized( scalarType: Float.self, - shape: .vector(1, stride: 1) + shape: .vector(logits.count, stride: 1) ) - try!
BNNS.applyReduction( - BNNS.ReductionFunction.argMax, - input: logitsDescriptor, - output: argmaxOutput!, - weights: nil + try BNNS.applyActivation( + activation: BNNS.ActivationFunction.softmax, + input: softmaxInput!, + output: softmaxOutput!, + batchSize: 1 ) - let argmaxResult = argmaxOutput!.makeArray(of: Float.self)! + if temperature != 0.0 { + // top-k multinomial sampling + let k = decodingOptions.topK + + let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float.self, shape: .vector(k, stride: 1)) + let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int32.self, shape: .vector(k, stride: 1)) + + try! BNNS.applyTopK( + k: k, + input: softmaxOutput!, + bestValues: bestValues, + bestIndices: bestIndices, + axis: 0, + batchSize: 1 + ) + + let bestValuesResult = bestValues.makeArray(of: Float.self)! + let bestIndicesResult = bestIndices.makeArray(of: Int32.self)! + + bestValues.deallocate() + bestIndices.deallocate() + + // multinomial sample from top-k + let sumOfbestIndicesResult = bestValuesResult.reduce(0, +) + let rnd = Float.random(in: 0.. (logits: MLMultiArray?, cache: DecodingCache?)? + _ inputs: any TextDecoderInputType + ) async throws -> TextDecoderOutputType? func prefillKVCache( withTask task: MLMultiArray, @@ -33,7 +71,7 @@ public protocol TextDecoding { ) async throws -> DecodingCache? func decodeText( - from encoderOutput: MLMultiArray, + from encoderOutput: any AudioEncoderOutputType, using decoderInputs: DecodingInputs, sampler tokenSampler: TokenSampling, options decoderOptions: DecodingOptions, @@ -51,7 +89,7 @@ public protocol TextDecoding { ) async throws -> [DecodingResult] func detectLanguage( - from encoderOutput: MLMultiArray, + from encoderOutput: any AudioEncoderOutputType, using decoderInputs: DecodingInputs, sampler tokenSampler: TokenSampling, options: DecodingOptions, @@ -310,6 +348,32 @@ public extension TextDecoding { } } + static func updateAlignmentWeights( + alignmentTensor: MLMultiArray, + alignmentSlice: MLMultiArray, + insertAtIndex tokenIndex: Int + ) { + let tensorShape = alignmentTensor.shape.map { $0.intValue } + let sliceStrides = alignmentSlice.strides.map { $0.intValue } + let bytesPerSample = MemoryLayout.size + + alignmentTensor.withUnsafeMutableBytes { alignmentPointer, alignmentStrides in + alignmentSlice.withUnsafeBytes { slicePointer in + // Process each column + for column in 0.. TextDecoderOutputType? { + guard let inputs = inputs as? TextDecoderMLMultiArrayInputType else { + throw WhisperError.transcriptionFailed("Input must be TextDecoderMLMultiArrayInputType") + } + + let result = try await predictLogits( + inputIds: inputs.inputIds, + cacheLength: inputs.cacheLength, + keyCache: inputs.keyCache, + valueCache: inputs.valueCache, + kvCacheUpdateMask: inputs.kvCacheUpdateMask, + encoderOutputEmbeds: inputs.encoderOutputEmbeds, + decoderKeyPaddingMask: inputs.decoderKeyPaddingMask + ) + + return TextDecoderMLMultiArrayOutputType(logits: result?.logits, cache: result?.cache) + } + public func predictLogits( inputIds: MLMultiArray, cacheLength: MLMultiArray, @@ -389,6 +473,10 @@ open class TextDecoder: TextDecoding, WhisperMLModel { encoderOutputEmbeds: MLMultiArray, decoderKeyPaddingMask: MLMultiArray ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? 
{ + guard let model = model else { + return nil + } + let modelInputs = TextDecoderInput( input_ids: inputIds, cache_length: cacheLength, @@ -399,10 +487,6 @@ decoder_key_padding_mask: decoderKeyPaddingMask ) - guard let model = model else { - return nil - } - try Task.checkCancellation() let outputFeatures = try await model.asyncPrediction(from: modelInputs, options: MLPredictionOptions()) @@ -420,7 +504,7 @@ } public func detectLanguage( - from encoderOutput: MLMultiArray, + from encoderOutput: any AudioEncoderOutputType, using decoderInputs: DecodingInputs, sampler tokenSampler: TokenSampling, options: DecodingOptions, @@ -464,15 +548,20 @@ let inferenceTime = Date() Logging.debug("Detecting language...") + guard let encoderOutput = encoderOutput as? MLMultiArray else { + throw WhisperError.prepareDecoderInputsFailed("Input must be MLMultiArray") + } let predictedLogits = try await self.predictLogits( - inputIds: decoderInputs.inputIds, - cacheLength: decoderInputs.cacheLength, - keyCache: decoderInputs.keyCache, - valueCache: decoderInputs.valueCache, - kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, - encoderOutputEmbeds: encoderOutput, - decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask - ) + TextDecoderMLMultiArrayInputType( + inputIds: decoderInputs.inputIds, + cacheLength: decoderInputs.cacheLength, + keyCache: decoderInputs.keyCache, + valueCache: decoderInputs.valueCache, + kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, + encoderOutputEmbeds: encoderOutput, + decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask + ) + ) as? TextDecoderMLMultiArrayOutputType guard let decoderOutput = predictedLogits else { Logging.error("Unable to decode logits") @@ -533,7 +622,7 @@ } public func decodeText( - from encoderOutput: MLMultiArray, + from encoderOutput: any AudioEncoderOutputType, using decoderInputs: DecodingInputs, sampler tokenSampler: TokenSampling, options: DecodingOptions, @@ -591,12 +680,8 @@ var hasAlignment = false var isFirstTokenLogProbTooLow = false let windowUUID = UUID() - Task { [weak self] in - guard let self = self else { return } - await MainActor.run { - self.shouldEarlyStop[windowUUID] = false - } - } + await earlyStopActor.set(false, for: windowUUID) + + for tokenIndex in prefilledIndex..<loopCount { diff --git a/Sources/WhisperKit/Core/Utils/Concurrency.swift b/Sources/WhisperKit/Core/Utils/Concurrency.swift new file mode 100644 index 0000000..58c94d6 --- /dev/null +++ b/Sources/WhisperKit/Core/Utils/Concurrency.swift @@ -0,0 +1,34 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved.
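// This actor replaces the previous pattern of hopping to the MainActor to mutate a
// shared `shouldEarlyStop` dictionary from the decoding loop. A usage sketch
// (illustrative only; mirrors how `decodeText` drives it per decoding window):
//
//     let earlyStop = EarlyStopActor()
//     let windowUUID = UUID()
//     await earlyStop.set(false, for: windowUUID)   // reset before the decoding loop
//     if await earlyStop.get(for: windowUUID) { /* stop decoding this window */ }
//     _ = await earlyStop.remove(for: windowUUID)   // clean up when the window finishes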
+ +import Foundation + +/// An actor that provides thread-safe early stopping functionality using UUIDs as keys +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public actor EarlyStopActor { + private var shouldStop = [UUID: Bool]() + + public init() {} + + /// Sets the stop flag for a given UUID + /// - Parameters: + /// - value: The boolean value to set + /// - uuid: The UUID key + public func set(_ value: Bool, for uuid: UUID) { + shouldStop[uuid] = value + } + + /// Gets the stop flag for a given UUID + /// - Parameter uuid: The UUID key + /// - Returns: The current stop flag value, or false if not set + public func get(for uuid: UUID) -> Bool { + return shouldStop[uuid] ?? false + } + + /// Removes and returns the stop flag for a given UUID + /// - Parameter uuid: The UUID key + /// - Returns: The removed stop flag value, if it existed + public func remove(for uuid: UUID) -> Bool? { + return shouldStop.removeValue(forKey: uuid) + } +} diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils/Utils.swift similarity index 93% rename from Sources/WhisperKit/Core/Utils.swift rename to Sources/WhisperKit/Core/Utils/Utils.swift index 77824e4..0a923be 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils/Utils.swift @@ -109,6 +109,74 @@ extension MLMultiArray { } } +@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) +public extension MLTensor { + func asIntArray() -> [Int] { + let semaphore = DispatchSemaphore(value: 0) + var result: [Int] = [] + + Task(priority: .high) { + result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) } + semaphore.signal() + } + + semaphore.wait() + return result + } + + func asFloatArray() -> [Float] { + let semaphore = DispatchSemaphore(value: 0) + let tensorType = self.scalarType + + var result: [Float] = [] + + Task(priority: .high) { + switch tensorType { + case is Float32.Type: + result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) } + case is FloatType.Type: + result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) } + case is Float.Type: + result = await self.shapedArray(of: Float.self).scalars.map { Float($0) } + case is Int32.Type: + result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) } + default: + fatalError("Unsupported data type") + } + semaphore.signal() + } + + semaphore.wait() + return result + } + + func asMLMultiArray() -> MLMultiArray { + let semaphore = DispatchSemaphore(value: 0) + let tensorType = self.scalarType + + var result: MLMultiArray = initMLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0) + + Task(priority: .high) { + switch tensorType { + case is Float32.Type: + result = MLMultiArray(await self.shapedArray(of: Float32.self)) + case is FloatType.Type: + result = MLMultiArray(await self.shapedArray(of: FloatType.self)) + case is Float.Type: + result = MLMultiArray(await self.shapedArray(of: Float.self)) + case is Int32.Type: + result = MLMultiArray(await self.shapedArray(of: Int32.self)) + default: + fatalError("Unsupported data type") + } + semaphore.signal() + } + + semaphore.wait() + return result + } +} + extension MLModel { func asyncPrediction( from input: MLFeatureProvider, diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index f782c98..6cccf01 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -151,14 +151,22 @@ open class WhisperKit { return modelSupport(for: 
deviceName) } - public static func recommendedRemoteModels(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupport { + public static func recommendedRemoteModels( + from repo: String = "argmaxinc/whisperkit-coreml", + downloadBase: URL? = nil, + token: String? = nil + ) async -> ModelSupport { let deviceName = Self.deviceName() - let config = await Self.fetchModelSupportConfig(from: repo, downloadBase: downloadBase) + let config = await Self.fetchModelSupportConfig(from: repo, downloadBase: downloadBase, token: token) return modelSupport(for: deviceName, from: config) } - public static func fetchModelSupportConfig(from repo: String = "argmaxinc/whisperkit-coreml", downloadBase: URL? = nil) async -> ModelSupportConfig { - let hubApi = HubApi(downloadBase: downloadBase) + public static func fetchModelSupportConfig( + from repo: String = "argmaxinc/whisperkit-coreml", + downloadBase: URL? = nil, + token: String? = nil + ) async -> ModelSupportConfig { + let hubApi = HubApi(downloadBase: downloadBase, hfToken: token) var modelSupportConfig = Constants.fallbackModelSupportConfig do { @@ -175,8 +183,13 @@ return modelSupportConfig } - public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["*"], downloadBase: URL? = nil) async throws -> [String] { - let modelSupportConfig = await fetchModelSupportConfig(from: repo, downloadBase: downloadBase) + public static func fetchAvailableModels( + from repo: String = "argmaxinc/whisperkit-coreml", + matching: [String] = ["*"], + downloadBase: URL? = nil, + token: String? = nil + ) async throws -> [String] { + let modelSupportConfig = await fetchModelSupportConfig(from: repo, downloadBase: downloadBase, token: token) let supportedModels = modelSupportConfig.modelSupport().supported var filteredSupportSet: Set<String> = [] for glob in matching { @@ -228,9 +241,10 @@ downloadBase: URL? = nil, useBackgroundSession: Bool = false, from repo: String = "argmaxinc/whisperkit-coreml", + token: String? = nil, progressCallback: ((Progress) -> Void)? = nil ) async throws -> URL { - let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession) + let hubApi = HubApi(downloadBase: downloadBase, hfToken: token, useBackgroundSession: useBackgroundSession) let repo = Hub.Repo(id: repo, type: .models) let modelSearchPath = "*\(variant.description)/*" do { diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index b19b213..62fcd72 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -342,6 +342,202 @@ final class UnitTests: XCTestCase { XCTAssertGreaterThan(energyVeryLoud, energyLoud, "Audio energy is not very loud") } + // MARK: - Protocol Conformance Tests + + func testMLMultiArrayConformsToFeatureExtractorOutputType() { + let array = try! MLMultiArray(shape: [1], dataType: .float16) + XCTAssertNotNil(array as FeatureExtractorOutputType) + } + + func testMLMultiArrayConformsToAudioEncoderOutputType() { + let array = try! MLMultiArray(shape: [1], dataType: .float16) + XCTAssertNotNil(array as AudioEncoderOutputType) + } + + func testMLMultiArrayConformsToTextDecoderTensorType() { + let array = try!
MLMultiArray(shape: [1], dataType: .float16) + XCTAssertNotNil(array as TextDecoderTensorType) + } + + // MARK: - Generic Type Tests + + func testEncodeFeatureWithGenericType() async throws { + let audioEncoder = AudioEncoder() + let modelPath = try URL(filePath: tinyModelPath()).appending(path: "AudioEncoder.mlmodelc") + try await audioEncoder.loadModel(at: modelPath, computeUnits: .cpuAndNeuralEngine) + + // Create a test input that conforms to FeatureExtractorOutputType + let input = try MLMultiArray(shape: [1, 80, 1, 3000], dataType: .float16) + + // Test encoding with generic type + let output = try await audioEncoder.encodeFeatures(input) + + XCTAssertNotNil(output) + XCTAssertNotNil(output! as AudioEncoderOutputType) + + // Test specific shape of output + if let mlOutput = output { + XCTAssertEqual(mlOutput.shape, [1, 384, 1, 1500]) + } else { + XCTFail("Output should be MLMultiArray") + } + } + + func testEncodeFeatureWithInvalidType() async throws { + let audioEncoder = AudioEncoder() + let modelPath = try URL(filePath: tinyModelPath()).appending(path: "AudioEncoder.mlmodelc") + try await audioEncoder.loadModel(at: modelPath, computeUnits: .cpuAndNeuralEngine) + + // Create an invalid input type + struct InvalidType: FeatureExtractorOutputType {} + let invalidInput = InvalidType() + + // Test that encoding fails with invalid type + do { + _ = try await audioEncoder.encodeFeatures(invalidInput) + XCTFail("Should throw error for invalid input type") + } catch let WhisperError.audioProcessingFailed(message) { + XCTAssertEqual(message, "AudioEncoder input must be MLMultiArray") + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + // MARK: - TextDecoder Generic Type Tests + + func testPredictLogitsWithGenericType() async throws { + let textDecoder = TextDecoder() + let modelPath = try URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc") + try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) + + // Create test inputs + let input = try TextDecoderMLMultiArrayInputType( + inputIds: MLMultiArray(shape: [1], dataType: .int32), + cacheLength: MLMultiArray(shape: [1], dataType: .int32), + keyCache: MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16), + valueCache: MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16), + kvCacheUpdateMask: MLMultiArray(shape: [1, 224], dataType: .float16), + encoderOutputEmbeds: MLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16), + decoderKeyPaddingMask: MLMultiArray(shape: [1, 224], dataType: .float16) + ) + + // Test prediction with generic type + let output = try await textDecoder.predictLogits(input) + + XCTAssertNotNil(output) + XCTAssertNotNil(output as? 
TextDecoderMLMultiArrayOutputType) + } + + func testPredictLogitsWithInvalidType() async throws { + let textDecoder = TextDecoder() + let modelPath = try URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc") + try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) + + // Create an invalid input type + struct InvalidType: TextDecoderInputType {} + let invalidInput = InvalidType() + + // Test that prediction fails with invalid type + do { + _ = try await textDecoder.predictLogits(invalidInput) + XCTFail("Should throw error for invalid input type") + } catch let WhisperError.transcriptionFailed(message) { + XCTAssertEqual(message, "Input must be TextDecoderMLMultiArrayInputType") + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testTextDecoderMLMultiArrayInputType() { + let inputIds = try! MLMultiArray(shape: [1], dataType: .int32) + let cacheLength = try! MLMultiArray(shape: [1], dataType: .int32) + let keyCache = try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16) + let valueCache = try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16) + let kvCacheUpdateMask = try! MLMultiArray(shape: [1, 224], dataType: .float16) + let encoderOutputEmbeds = try! MLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16) + let decoderKeyPaddingMask = try! MLMultiArray(shape: [1, 224], dataType: .float16) + + let input = TextDecoderMLMultiArrayInputType( + inputIds: inputIds, + cacheLength: cacheLength, + keyCache: keyCache, + valueCache: valueCache, + kvCacheUpdateMask: kvCacheUpdateMask, + encoderOutputEmbeds: encoderOutputEmbeds, + decoderKeyPaddingMask: decoderKeyPaddingMask + ) + + XCTAssertNotNil(input as TextDecoderInputType) + XCTAssertEqual(input.inputIds.shape, [1]) + XCTAssertEqual(input.cacheLength.shape, [1]) + XCTAssertEqual(input.keyCache.shape, [1, 1536, 1, 224]) + XCTAssertEqual(input.valueCache.shape, [1, 1536, 1, 224]) + XCTAssertEqual(input.kvCacheUpdateMask.shape, [1, 224]) + XCTAssertEqual(input.encoderOutputEmbeds.shape, [1, 384, 1, 1500]) + XCTAssertEqual(input.decoderKeyPaddingMask.shape, [1, 224]) + } + + func testTextDecoderMLMultiArrayOutputType() { + let logits = try! MLMultiArray(shape: [1, 51865, 1, 1], dataType: .float16) + let cache = DecodingCache( + keyCache: try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16), + valueCache: try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16), + alignmentWeights: try! MLMultiArray(shape: [1, 224], dataType: .float16) + ) + + let output = TextDecoderMLMultiArrayOutputType(logits: logits, cache: cache) + + XCTAssertNotNil(output as TextDecoderOutputType) + XCTAssertEqual(output.logits?.shape, [1, 51865, 1, 1]) + XCTAssertNotNil(output.cache) + XCTAssertEqual(output.cache?.keyCache?.shape, [1, 1536, 1, 224]) + XCTAssertEqual(output.cache?.valueCache?.shape, [1, 1536, 1, 224]) + XCTAssertEqual(output.cache?.alignmentWeights?.shape, [1, 224]) + } + + func testTextDecoderMLMultiArrayOutputTypeWithNilValues() { + let output = TextDecoderMLMultiArrayOutputType() + + XCTAssertNotNil(output as TextDecoderOutputType) + XCTAssertNil(output.logits) + XCTAssertNil(output.cache) + } + + func testDecodingCacheInitialization() { + let keyCache = try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16) + let valueCache = try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16) + let alignmentWeights = try! 
MLMultiArray(shape: [1, 224], dataType: .float16) + + let cache = DecodingCache( + keyCache: keyCache, + valueCache: valueCache, + alignmentWeights: alignmentWeights + ) + + XCTAssertEqual(cache.keyCache?.shape, [1, 1536, 1, 224]) + XCTAssertEqual(cache.valueCache?.shape, [1, 1536, 1, 224]) + XCTAssertEqual(cache.alignmentWeights?.shape, [1, 224]) + } + + func testDecodingCacheWithNilValues() { + let cache = DecodingCache() + + XCTAssertNil(cache.keyCache) + XCTAssertNil(cache.valueCache) + XCTAssertNil(cache.alignmentWeights) + } + + func testDecodingCacheWithPartialValues() { + let keyCache = try! MLMultiArray(shape: [1, 1536, 1, 224], dataType: .float16) + + let cache = DecodingCache(keyCache: keyCache) + + XCTAssertNotNil(cache.keyCache) + XCTAssertNil(cache.valueCache) + XCTAssertNil(cache.alignmentWeights) + XCTAssertEqual(cache.keyCache?.shape, [1, 1536, 1, 224]) + } + // MARK: - Feature Extractor Tests func testLogmelOutput() async throws { @@ -561,8 +757,8 @@ final class UnitTests: XCTestCase { ) XCTAssertNotNil(result) - let tokenCount = result.segments.flatMap { $0.tokens }.count - let decodingTimePerToken = result.timings.decodingLoop / Double(tokenCount) + let tokenCountWithEarlyStop = result.segments.flatMap { $0.tokens }.count + let decodingTimePerTokenWithEarlyStop = result.timings.decodingLoop / Double(tokenCountWithEarlyStop) // Work done in the callback should not block the decoding loop let continuationCallbackWithWait: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in @@ -581,10 +777,10 @@ final class UnitTests: XCTestCase { Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)") // Assert that the decoding predictions per token are not slower with the waiting - XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting") + XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerTokenWithEarlyStop, accuracy: decodingTimePerTokenWithEarlyStop, "Decoding predictions per token should not be significantly slower with waiting") // Assert that more tokens are returned in the callback with waiting - XCTAssertGreaterThan(tokenCountWithWait, tokenCount, "More tokens should be returned in the callback with waiting") + XCTAssertGreaterThan(tokenCountWithWait, tokenCountWithEarlyStop, "More tokens should be returned in the callback with waiting") } // MARK: - Tokenizer Tests @@ -662,7 +858,7 @@ final class UnitTests: XCTestCase { let transcribeResult: [TranscriptionResult] = try await whisperKit.transcribe(audioArray: multiWindowSamples, decodeOptions: options) let result = try XCTUnwrap(transcribeResult.first) - XCTAssertEqual(result.segments.count, 2, "Expected 3 segments") + XCTAssertEqual(result.segments.count, 3, "Expected 3 segments") // Compare last timestamp to the length of the audio let endTimestamp = try XCTUnwrap( @@ -1064,7 +1260,7 @@ final class UnitTests: XCTestCase { "Failed to transcribe" ) - XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country.") + XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country") } func testCallbacks() async throws { @@ -1223,91 +1419,76 @@ final class UnitTests: XCTestCase { } func testTimestampRulesFilter() throws { - // NOTE: for 
non-multilingual models we supress tokens immediately - let tokensFilter1 = TimestampRulesFilter( + // NOTE: for non-multilingual models we suppress tokens immediately + let tokensFilter = TimestampRulesFilter( specialTokens: .default( endToken: 3, noTimestampsToken: 2, - timeTokenBegin: 4, - transcribeToken: 100, - translateToken: 101 + timeTokenBegin: 6, + transcribeToken: 4, + translateToken: 5 ), - sampleBegin: 2, + sampleBegin: 0, maxInitialTimestampIndex: nil, isModelMultilingual: false ) - let logits1 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let result1 = tokensFilter1.filterLogits(logits1, withTokens: []) - XCTAssertEqual(result1.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, 0.2]) - let tokensFilter2 = TimestampRulesFilter( - specialTokens: .default( - endToken: 3, - noTimestampsToken: 2, - timeTokenBegin: 4, - transcribeToken: 100, - translateToken: 101 - ), - sampleBegin: 2, - maxInitialTimestampIndex: nil, - isModelMultilingual: false - ) + // noTimestampToken should always be suppressed if tokens pass sampleBegin + let logits1 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result1 = tokensFilter.filterLogits(logits1, withTokens: [4]) + XCTAssertEqual(result1.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) - let logits2 = try MLMultiArray.logits([1.1, 0.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let result2 = tokensFilter2.filterLogits(logits2, withTokens: []) - XCTAssertEqual(result2.data(for: 2), [-.infinity, -.infinity, -.infinity, -.infinity, 0.2, 0.1, 0.2]) + // Timestamps should not decrease (filters up to last seen timestamp) + let logits2 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result2 = tokensFilter.filterLogits(logits2, withTokens: [0, 6, 7, 3]) + XCTAssertEqual(result2.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, -.infinity, -.infinity, 0.1]) + + // If last two tokens are timestamps, filter all timestamps (allows text token to be next) + let logits3 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result3 = tokensFilter.filterLogits(logits3, withTokens: [0, 6, 7]) + XCTAssertEqual(result3.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, -.infinity, -.infinity, -.infinity]) + + // If only one previous token was a timestamp, filter all text and non-decreasing timestamps (to find matching timestamp pair) + let logits4 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result4 = tokensFilter.filterLogits(logits4, withTokens: [0, 4, 7]) + XCTAssertEqual(result4.data(for: 2), [-.infinity, -.infinity, -.infinity, -.infinity, -.infinity, -.infinity, -.infinity, 0.1, 0.1]) } func testTimestampRulesFilterMultilingual() throws { - // NOTE: for multilingual models we supress tokens only after transcribe or translate token - let tokensFilter1 = TimestampRulesFilter( + // NOTE: for multilingual models we suppress tokens only after transcribe or translate token + let tokensFilter = TimestampRulesFilter( specialTokens: .default( endToken: 3, noTimestampsToken: 2, - timeTokenBegin: 4, - transcribeToken: 100, - translateToken: 101 + timeTokenBegin: 6, + transcribeToken: 4, + translateToken: 5 ), - sampleBegin: 2, + sampleBegin: 0, maxInitialTimestampIndex: nil, isModelMultilingual: true ) - let logits1 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let result1 = tokensFilter1.filterLogits(logits1, withTokens: []) - XCTAssertEqual(result1.data(for: 2), [1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let 
tokensFilter2 = TimestampRulesFilter( - specialTokens: .default( - endToken: 3, - noTimestampsToken: 2, - timeTokenBegin: 4, - transcribeToken: 100, - translateToken: 101 - ), - sampleBegin: 2, - maxInitialTimestampIndex: nil, - isModelMultilingual: true - ) - let logits2 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let result2 = tokensFilter2.filterLogits(logits2, withTokens: [100]) - XCTAssertEqual(result2.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, 0.2]) + // Without task token, nothing should be suppressed even with tokens past sampleBegin + let logits1 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result1 = tokensFilter.filterLogits(logits1, withTokens: [0, 1, 2]) + XCTAssertEqual(result1.data(for: 2), [1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) - let tokensFilter3 = TimestampRulesFilter( - specialTokens: .default( - endToken: 3, - noTimestampsToken: 2, - timeTokenBegin: 4, - transcribeToken: 100, - translateToken: 101 - ), - sampleBegin: 2, - maxInitialTimestampIndex: nil, - isModelMultilingual: true - ) - let logits3 = try MLMultiArray.logits([1.1, 0.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - let result3 = tokensFilter3.filterLogits(logits3, withTokens: [101]) - XCTAssertEqual(result3.data(for: 2), [-.infinity, -.infinity, -.infinity, -.infinity, 0.2, 0.1, 0.2]) + // Timestamps should not decrease after task token (filters up to last seen timestamp) + let logits2 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result2 = tokensFilter.filterLogits(logits2, withTokens: [0, 4, 6, 7, 3]) + XCTAssertEqual(result2.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, -.infinity, -.infinity, 0.1]) + + // If last two tokens after task are timestamps, filter all timestamps (allows text token to be next) + let logits3 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result3 = tokensFilter.filterLogits(logits3, withTokens: [0, 5, 6, 7]) + XCTAssertEqual(result3.data(for: 2), [1.1, 5.2, -.infinity, 0.4, 0.2, 0.1, -.infinity, -.infinity, -.infinity]) + + // After transcribe token with text and single timestamp (should force timestamp tokens) + let logits4 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2, 0.1, 0.1]) + let result4 = tokensFilter.filterLogits(logits4, withTokens: [0, 4, 0, 7]) + XCTAssertEqual(result4.data(for: 2), [-.infinity, -.infinity, -.infinity, -.infinity, -.infinity, -.infinity, -.infinity, 0.1, 0.1]) } // MARK: - VAD Tests
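A usage note on the new `token` parameter threaded through the Hub calls above: it lets gated or private model repos be resolved with a Hugging Face access token. A minimal sketch (the environment-variable name is an assumption for illustration, not part of this diff):

    import Foundation
    import WhisperKit

    // Read a Hugging Face access token from the environment (nil falls back to anonymous access)
    let hfToken = ProcessInfo.processInfo.environment["HF_TOKEN"]

    // Resolve the recommended and available models for this device, authenticating with the token
    let supported = await WhisperKit.recommendedRemoteModels(token: hfToken)
    let available = try await WhisperKit.fetchAvailableModels(
        matching: ["openai_whisper-tiny*"],
        token: hfToken
    )
    print(supported.default, available)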