
Backwards compatible generic model io
* Support generic io for model inputs and outputs

* Add speed factor to timing report

* Use actor for early stop checks for better concurrency safety (see the sketch after the commit metadata below)

* Add io type protocol handling and tests

* Formatting

* Fix timestamp token filter logic and tests

* Run unit tests on any branch in PR

* Upload test failure results
ZachNagengast committed Dec 19, 2024
1 parent f63313f commit 2b4c011
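The actor-based early-stop change mentioned in the bullets above lands in a file that is not among the hunks shown below. As an illustration of the pattern only, here is a minimal sketch with hypothetical names: an actor serializes access to the stop flags so that concurrently decoding windows cannot race on them.

// Sketch of actor-isolated early-stop state; WhisperKit's actual type and
// member names may differ.
actor EarlyStopChecker {
    // One flag per decoding window, keyed by window index.
    private var shouldStop: [Int: Bool] = [:]

    func requestStop(for window: Int) {
        shouldStop[window] = true
    }

    func shouldEarlyStop(for window: Int) -> Bool {
        shouldStop[window] ?? false
    }
}

// Usage from a decoding loop: the await hops onto the actor, so checks and
// updates are serialized even when several windows decode concurrently.
// if await earlyStopChecker.shouldEarlyStop(for: windowIndex) { break }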
Showing 15 changed files with 712 additions and 243 deletions.
1 change: 0 additions & 1 deletion .github/workflows/development-tests.yml
@@ -2,7 +2,6 @@ name: Development Tests

on:
  pull_request:
-     branches: ["main"]
  pull_request_review:
    types: [submitted]
  workflow_dispatch:
11 changes: 11 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -75,8 +75,19 @@ jobs:
          sleep 15
          xcrun simctl list devices
      - name: Build and Test - ${{ matrix.run-config['name'] }}
+         id: test-step
        if: ${{ matrix.run-config['condition'] == true }}
+         continue-on-error: true
        run: |
          set -o pipefail
          xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
          xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
+
+       - name: Upload Test Results
+         if: failure() && steps.test-step.outcome == 'failure'
+         uses: actions/upload-artifact@v4
+         with:
+           name: test-results-${{ matrix.run-config['name'] }}
+           path: |
+             ~/Library/Developer/Xcode/DerivedData/**/Logs/Test/*.xcresult
+           retention-days: 5
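A note on the pattern above: continue-on-error: true lets the job keep running after a failed test step, and steps.test-step.outcome reports the step's own result before continue-on-error is applied, so the upload step can collect .xcresult bundles only from runs where the tests actually failed.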
8 changes: 6 additions & 2 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,7 +93,7 @@ public extension AudioProcessing {
    }

    static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-         guard startIndex >= 0 && startIndex < audioArray.count else {
+         guard startIndex >= 0, startIndex < audioArray.count else {
            Logging.error("startIndex is outside the buffer size")
            return nil
        }
@@ -228,7 +228,11 @@ public class AudioProcessor: NSObject, AudioProcessing {
        guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else {
            throw WhisperError.loadAudioFailed("Unable to create audio buffer")
        }
-         try audioFile.read(into: buffer, frameCount: frameCount)
+         do {
+             try audioFile.read(into: buffer, frameCount: frameCount)
+         } catch {
+             throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)")
+         }
        outputBuffer = buffer
    } else {
        // Audio needs resampling to 16khz
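Wrapping the read in do/catch surfaces the underlying AVFoundation error inside the WhisperError message. A minimal sketch of what a caller sees, assuming the static loadAudio(fromPath:) entry point on AudioProcessor:

import AVFoundation
import WhisperKit

// Sketch only: a failed read now arrives as WhisperError.loadAudioFailed with
// the underlying error's description embedded in the message.
do {
    let buffer = try AudioProcessor.loadAudio(fromPath: "/path/to/audio.wav")
    print("Loaded \(buffer.frameLength) frames")
} catch WhisperError.loadAudioFailed(let message) {
    print("Audio load failed: \(message)")
}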
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift
@@ -117,7 +117,7 @@ open class VoiceActivityDetector {
        }
    }

-     // MARK - Utility
+     // MARK: - Utility

    func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
        let nonSilentChunks = calculateActiveChunks(in: waveform)
18 changes: 15 additions & 3 deletions Sources/WhisperKit/Core/AudioEncoder.swift
@@ -3,17 +3,22 @@

import CoreML

+ public protocol AudioEncoderOutputType {}
+ extension MLMultiArray: AudioEncoderOutputType {}
+
/// AudioEncoding protocol defines the requirements for an audio encoding implementation.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public protocol AudioEncoding {
    /// The size of the embedding produced by the encoder.
    var embedSize: Int? { get }

    /// Encodes the given audio features asynchronously.
    /// - Parameter features: The audio features to be encoded.
-     /// - Returns: An optional `MLMultiArray` containing the encoded features.
-     func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray?
+     /// - Returns: An optional tensor containing the encoded features.
+     func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)?
}

+ /// Backwards-compatible AudioEncoder implementation
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public class AudioEncoder: AudioEncoding, WhisperMLModel {
    public var model: MLModel?
@@ -36,8 +41,15 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel {

    public init() {}

+     public func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
+         guard let features = features as? MLMultiArray else {
+             throw WhisperError.audioProcessingFailed("AudioEncoder input must be MLMultiArray")
+         }
+
+         return try await encodeFeatures(features)
+     }

    public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? {
        // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000)
        guard let model else {
            throw WhisperError.modelsUnavailable()
        }
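The existential-typed protocol requirement plus the downcasting overload above is what keeps this backwards compatible: existing MLMultiArray call sites continue to compile, while a custom backend can route its own tensor type through the same protocol. A sketch of such a conformance, where MyTensor and MyAudioEncoder are hypothetical names, not part of WhisperKit:

import CoreML
import WhisperKit

// Hypothetical tensor type from a non-CoreML backend.
struct MyTensor {
    var data: [Float]
    var shape: [Int]
}

// Opting into the generic IO protocols is a one-line conformance each.
extension MyTensor: FeatureExtractorOutputType {}
extension MyTensor: AudioEncoderOutputType {}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class MyAudioEncoder: AudioEncoding {
    var embedSize: Int? { 384 }

    func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
        // Mirror the built-in encoder: validate the expected concrete type first.
        guard let mel = features as? MyTensor else {
            throw WhisperError.audioProcessingFailed("MyAudioEncoder input must be MyTensor")
        }
        // A real implementation would run its own model here; this placeholder
        // just passes the features through.
        return MyTensor(data: mel.data, shape: mel.shape)
    }
}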
7 changes: 6 additions & 1 deletion Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -7,10 +7,15 @@ import CoreGraphics
import CoreML
import Foundation

+ public protocol FeatureExtractorOutputType {}
+ extension MLMultiArray: FeatureExtractorOutputType {}
+
public protocol FeatureExtracting {
+     associatedtype OutputType: FeatureExtractorOutputType
+
    var melCount: Int? { get }
    var windowSamples: Int? { get }
-     func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray?
+     func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> OutputType?
}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
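On the extractor side, the new associatedtype pins each conforming type to a single concrete output, and the MLMultiArray conformance added above means the stock extractor keeps its existing return type. A minimal sketch of a conforming stub (the type below is hypothetical):

import CoreML
import WhisperKit

final class StubFeatureExtractor: FeatureExtracting {
    // OutputType is inferred as MLMultiArray from the return type below, and
    // MLMultiArray satisfies FeatureExtractorOutputType per the extension above.
    let melCount: Int? = 80
    let windowSamples: Int? = 480_000 // 30 seconds of 16 kHz audio

    func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
        // Real DSP omitted; a stub may simply return nil.
        return nil
    }
}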
3 changes: 2 additions & 1 deletion Sources/WhisperKit/Core/Models.swift
@@ -299,7 +299,7 @@ public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
    }
}

- public struct DecodingInputs {
+ open class DecodingInputs {
    public var initialPrompt: [Int]
    public var inputIds: MLMultiArray
    public var cacheLength: MLMultiArray
@@ -580,6 +580,7 @@ public struct TranscriptionResult: Codable {
            Total Tokens: \(totalTokens)
            Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s
            Real Time Factor: \(String(format: "%.3f", rtf))
+             Speed Factor: \(String(format: "%.3f", 1.0 / rtf))
            Fallbacks: \(timings.totalDecodingFallbacks)
            """)
    }
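For context, the new Speed Factor line is just the inverse of the real-time factor, so the two always move together. An illustrative calculation with made-up timings:

// Transcribing 60 seconds of audio in 15 seconds of wall-clock time:
let rtf = 15.0 / 60.0       // Real Time Factor = 0.25
let speedFactor = 1.0 / rtf // Speed Factor = 4.0, i.e. 4x faster than real time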
3 changes: 2 additions & 1 deletion Sources/WhisperKit/Core/Text/LogitsFilter.swift
@@ -75,8 +75,9 @@ open class TimestampRulesFilter: LogitsFiltering {

    public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
        guard let sampleBegin = sampleBegin(for: tokens),
-               sampleBegin > tokens.count
+               sampleBegin <= tokens.count
        else {
+             // Early return if we are still prefilling the prompt
            return logits
        }

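In other words, the original guard was inverted: the timestamp rules were applied only while the prompt was still being prefilled (tokens.count < sampleBegin) and skipped once actual decoding began. With sampleBegin <= tokens.count, the early return now fires during prefill and the rules apply afterwards, which is the filter-logic fix the commit message refers to.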