Commit cf0e2a5

turbo mode added
fguzman82 committed Jul 9, 2024
1 parent bdb8616 commit cf0e2a5
Showing 6 changed files with 206 additions and 36 deletions.
CLIP-Finder2/CLIPImageModel.swift (34 additions, 3 deletions)

@@ -61,7 +61,7 @@ final class CLIPImageModel {
 
     func performInference(_ pixelBuffer: CVPixelBuffer) async throws -> MLMultiArray? {
         guard let model = model else {
-            throw NSError(domain: "DataModel", code: 2, userInfo: [NSLocalizedDescriptionKey: "Model is not loaded"])
+            throw NSError(domain: "ClipImageModel", code: 2, userInfo: [NSLocalizedDescriptionKey: "Model is not loaded"])
         }
 
         let input = InputFeatureProvider(pixelBuffer: pixelBuffer)
@@ -72,15 +72,46 @@ final class CLIPImageModel {
             if let multiArray = outputFeatures.featureValue(for: "var_1259")?.multiArrayValue {
                 return multiArray
             } else {
-                throw NSError(domain: "DataModel", code: 3, userInfo: [NSLocalizedDescriptionKey: "Failed to retrieve MLMultiArray from prediction"])
+                throw NSError(domain: "ClipImageModel", code: 3, userInfo: [NSLocalizedDescriptionKey: "Failed to retrieve MLMultiArray from prediction"])
             }
         } catch {
             #if DEBUG
-            print("Failed to perform inference: \(error)")
+            print("ClipImageModel: Failed to perform inference: \(error)")
             #endif
             throw error
         }
     }
 
+    func performInferenceSync(_ pixelBuffer: CVPixelBuffer) -> MLMultiArray? {
+        guard let model else {
+            #if DEBUG
+            print("ClipImageModel is not loaded.")
+            #endif
+            return nil
+        }
+
+        let input = InputFeatureProvider(pixelBuffer: pixelBuffer)
+        do {
+            let outputFeatures = try model.prediction(from: input)
+
+            if let multiArray = outputFeatures.featureValue(for: "var_1259")?.multiArrayValue {
+                return multiArray
+            } else {
+                #if DEBUG
+                print("ClipImageModel: Failed to retrieve MLMultiArray.")
+                #endif
+                return nil
+            }
+        } catch {
+            #if DEBUG
+            print("ClipImageModel: Failed to perform inference: \(error)")
+            #endif
+            return nil
+        }
+    }
 }
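The new performInferenceSync works because MLModel.prediction(from:) is itself a plain blocking call, so the camera pipeline can consume a result without hopping through an async context. A minimal sketch of the two call shapes; the model path below is a placeholder, not the app's real asset:

    import CoreML

    // Load a compiled Core ML model; the path here is hypothetical.
    func loadModel() throws -> MLModel {
        let url = URL(fileURLWithPath: "/path/to/CLIPImage.mlmodelc")
        return try MLModel(contentsOf: url)
    }

    // Synchronous path (turbo off): blocks the caller until the prediction is done.
    func predictSync(model: MLModel, input: MLFeatureProvider) -> MLFeatureProvider? {
        try? model.prediction(from: input)
    }

    // Async path (turbo on): the same blocking call, wrapped so callers can await it.
    func predictAsync(model: MLModel, input: MLFeatureProvider) async -> MLFeatureProvider? {
        try? model.prediction(from: input)
    }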
CLIP-Finder2/CameraPreviewView.swift (68 additions, 13 deletions)

@@ -16,6 +16,7 @@ struct CameraPreviewView: View {
     @State private var showFocusCircle = false
     @State private var isPaused: Bool = false
     @State private var lastFrame: UIImage?
+    @State private var showTurboModeAlert = false
 
     var body: some View {
         ZStack {
Expand Down Expand Up @@ -53,16 +54,36 @@ struct CameraPreviewView: View {

Spacer()

CameraButton(action: {
photoGalleryViewModel.togglePause()
}, imageName: photoGalleryViewModel.isPaused ? "play.circle.fill" : "pause.circle.fill")
.frame(width: 60, height: 60)
HStack {
Spacer()
CameraButton(action: {
photoGalleryViewModel.togglePause()
}, imageName: photoGalleryViewModel.isPaused ? "play.circle.fill" : "pause.circle.fill")
.frame(width: 60, height: 60)
Spacer()
TurboButton(viewModel: photoGalleryViewModel, showAlert: $showTurboModeAlert)
}
.padding(.horizontal)
.padding(.bottom, 10)
}

if showFocusCircle, let point = focusPoint {
FocusCircleView(point: point)
}
}
.alert(isPresented: $showTurboModeAlert) {
Alert(
title: Text("Activate Turbo Mode"),
message: Text("Turbo Mode enables asynchronous CLIP image prediction. It's faster but may freeze the app (beta feature)."),
primaryButton: .default(Text("Activate")) {
photoGalleryViewModel.toggleTurboMode()
photoGalleryViewModel.finalizeTurboToggle()
},
secondaryButton: .cancel {
photoGalleryViewModel.finalizeTurboToggle()
}
)
}
.onAppear {
photoGalleryViewModel.startCamera()
photoGalleryViewModel.onFrameCaptured = { image in
@@ -81,15 +102,12 @@ struct CameraButton: View {
 
     var body: some View {
         Button(action: action) {
-            ZStack {
-                Circle()
-                    .fill(Color.black.opacity(0.5))
-                    .frame(width: 45, height: 45)
-
-                Image(systemName: imageName)
-                    .font(.system(size: 20))
-                    .foregroundColor(.white)
-            }
+            Image(systemName: imageName)
+                .font(.system(size: 30))
+                .foregroundColor(.white)
+                .frame(width: 45, height: 45)
+                .background(Color.black.opacity(0.5))
+                .clipShape(Circle())
         }
     }
 }
@@ -190,3 +208,40 @@ struct CameraPreview: UIViewRepresentable {
         }
     }
 }
+
+
+struct TurboButton: View {
+    @ObservedObject var viewModel: PhotoGalleryViewModel
+    @Binding var showAlert: Bool
+
+    var body: some View {
+        Button(action: {
+            if !viewModel.useAsyncImageSearch {
+                viewModel.prepareTurboToggle()
+                showAlert = true
+            } else {
+                viewModel.toggleTurboMode()
+                viewModel.finalizeTurboToggle()
+            }
+        }) {
+            ZStack {
+                Circle()
+                    .fill(viewModel.useAsyncImageSearch ? Color.yellow : Color.gray)
+                    .frame(width: 40, height: 40)
+
+                Image(systemName: "bolt.fill")
+                    .foregroundColor(viewModel.useAsyncImageSearch ? .black : .white)
+                    .font(.system(size: 20))
+            }
+        }
+        .overlay(
+            Text("Turbo")
+                .font(.system(size: 10))
+                .foregroundColor(.white)
+                .padding(2)
+                .background(Color.black.opacity(0.6))
+                .cornerRadius(4)
+                .offset(y: 25)
+        )
+    }
+}
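TurboButton gates enabling behind the confirmation alert (prepareTurboToggle, then the alert's buttons either commit or cancel), while disabling flips the flag immediately. A self-contained sketch of that confirm-before-enable pattern; the names below are illustrative, not the app's API:

    import SwiftUI

    struct GatedToggleDemo: View {
        @State private var turboOn = false
        @State private var showConfirm = false

        var body: some View {
            Button(turboOn ? "Turbo: on" : "Turbo: off") {
                if turboOn {
                    turboOn = false      // disabling needs no confirmation
                } else {
                    showConfirm = true   // enabling is gated behind an alert
                }
            }
            .alert(isPresented: $showConfirm) {
                Alert(
                    title: Text("Activate Turbo Mode"),
                    message: Text("Faster, but experimental."),
                    primaryButton: .default(Text("Activate")) { turboOn = true },
                    secondaryButton: .cancel()
                )
            }
        }
    }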
CLIP-Finder2/ContentView.swift (5 additions, 4 deletions)

@@ -21,8 +21,11 @@ struct ContentView: View {
     var body: some View {
         NavigationView {
             GeometryReader { geometry in
-                ZStack {
-                    VStack {
+                VStack {
+                    if photoGalleryViewModel.isGalleryEmpty {
+                        Text("Your photo gallery is empty. Add some photos to use CLIP-Finder.")
+                            .padding()
+                    } else {
                         HStack {
                             TextField("Enter search text", text: $searchText)
                                 .textFieldStyle(RoundedBorderTextFieldStyle())
@@ -92,10 +95,8 @@
                 .padding()
             }
         }
 
         }
-
-
     }
 }
 .navigationTitle("CLIP-Finder")
CLIP-Finder2/CoreDataManager.swift (0 additions, 2 deletions)

@@ -141,5 +141,3 @@ class CoreDataManager {
     }
 }
 
-
-
CLIP-Finder2/PhotoGalleryViewModel.swift (92 additions, 14 deletions)

@@ -18,6 +18,7 @@ import CoreImage
 class PhotoGalleryViewModel: ObservableObject {
     @Published var assets: [PHAsset] = []
     @Published var topPhotoIDs: [String] = []
+    @Published var isGalleryEmpty: Bool = true
 
     private var customTokenizer: CLIPTokenizer?
     private var clipTextModel: CLIPTextModel
@@ -31,9 +32,11 @@ class PhotoGalleryViewModel: ObservableObject {
     @Published var processedPhotosCount: Int = 0
     @Published var totalPhotosCount: Int = 0
     @Published var isProcessing: Bool = false
+    @Published var useAsyncImageSearch: Bool = false
+    private var wasPlayingBeforeTurbo: Bool = false
 
     private var updateTimer: Timer?
 
 
     init() {
         self.cameraManager = CameraManager()
@@ -46,10 +49,10 @@
     private func setupCameraManager() {
         cameraManager.onFrameCaptured = { [weak self] ciImage in
             guard let self = self, self.isCameraActive, !self.isPaused else { return }
-//            self.performImageSearch(from: ciImage)
-            Task {
-                await self.performImageSearch(from: ciImage)
-            }
+            self.performImageSearch(from: ciImage)
+//            Task {
+//                await self.performImageSearch(from: ciImage)
+//            }
 
             self.onFrameCaptured?(ciImage)
         }
@@ -86,6 +89,23 @@
         }
     }
 
+    func prepareTurboToggle() {
+        wasPlayingBeforeTurbo = !isPaused
+        if !isPaused {
+            togglePause()
+        }
+    }
+
+    func finalizeTurboToggle() {
+        if wasPlayingBeforeTurbo && isPaused {
+            togglePause()
+        }
+    }
+
+    func toggleTurboMode() {
+        useAsyncImageSearch.toggle()
+    }
+
     func getCameraSession() -> AVCaptureSession {
         return cameraManager.session
     }
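prepareTurboToggle and finalizeTurboToggle form a small pause-restore gate around the alert: remember whether the camera was playing, pause it while the alert is up, and resume only if it was playing before. The same idea in isolation, under illustrative names:

    // Pause-around-a-modal gate, mirroring prepare/finalizeTurboToggle above.
    final class PlaybackGate {
        private(set) var isPaused = false
        private var wasPlaying = false

        func togglePause() { isPaused.toggle() }

        // Before presenting the modal: snapshot state, then pause.
        func prepare() {
            wasPlaying = !isPaused
            if !isPaused { togglePause() }
        }

        // When the modal is dismissed: resume only if we paused it.
        func finalize() {
            if wasPlaying && isPaused { togglePause() }
        }
    }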
@@ -123,6 +143,10 @@
             }
         }
     }
+
+    private func updateGalleryStatus() {
+        isGalleryEmpty = assets.isEmpty
+    }
 
     private func setupTokenizer() {
         guard let bpePath = Bundle.main.path(forResource: "bpe_simple_vocab_16e6", ofType: "txt") else {
@@ -134,6 +158,17 @@
     func processTextSearch(_ searchText: String) {
         searchTask?.cancel()
 
+        guard !searchText.isEmpty else {
+            #if DEBUG
+            print("Search text is empty, skipping search")
+            #endif
+
+            DispatchQueue.main.async {
+                self.topPhotoIDs = []
+            }
+            return
+        }
+
         searchTask = Task {
 
             try? await Task.sleep(nanoseconds: 300_000_000) // 300ms
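The cancel-then-sleep shape above is a task-based debounce: each keystroke cancels the previous search task, and only a task that survives the 300 ms sleep actually runs the query. The same pattern isolated, under illustrative names:

    import Foundation

    // Task-based debouncer; `action` stands in for the real search call.
    final class Debouncer {
        private var task: Task<Void, Never>?

        func submit(_ text: String, action: @escaping (String) -> Void) {
            task?.cancel()
            task = Task {
                try? await Task.sleep(nanoseconds: 300_000_000) // 300 ms
                guard !Task.isCancelled else { return }
                action(text)
            }
        }
    }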
@@ -147,13 +182,20 @@
     }
 
     private func performSearch(_ searchText: String) {
+        guard !isGalleryEmpty else {
+            #if DEBUG
+            print("Cannot perform search: Photo gallery is empty")
+            #endif
+            return
+        }
+
         guard let tokenizer = customTokenizer else {
             #if DEBUG
             print("Tokenizer not initialized")
             #endif
             return
         }
 
         let tokens = tokenizer.tokenize(texts: [searchText])
 
         Task {
@@ -176,7 +218,39 @@
         }
     }
 
-    func performImageSearch(from ciImage: CIImage) async {
+    func performImageSearch(from ciImage: CIImage) {
+        if useAsyncImageSearch {
+            Task {
+                await performImageSearchAsync(from: ciImage)
+            }
+        } else {
+            performImageSearchSync(from: ciImage)
+        }
+    }
+
+    func performImageSearchSync(from ciImage: CIImage) {
+        guard !isGalleryEmpty else {
+            #if DEBUG
+            print("Cannot perform search: Photo gallery is empty")
+            #endif
+            return
+        }
+        guard isCameraActive else { return }
+        guard let cgImage = CIContext().createCGImage(ciImage, from: ciImage.extent) else { return }
+        let uiImage = UIImage(cgImage: cgImage)
+
+        guard let pixelBuffer = Preprocessing.preprocessImage(uiImage, targetSize: CGSize(width: 256, height: 256)) else { return }
+
+        guard let imageFeatures = clipImageModel.performInferenceSync(pixelBuffer) else { return }
+
+        let topIDs = calculateAndPrintTopPhotoIDs(textFeatures: imageFeatures)
+        DispatchQueue.main.async {
+            self.topPhotoIDs = topIDs
+        }
+    }
+
+    // Async implementation of performImageSearch
+    func performImageSearchAsync(from ciImage: CIImage) async {
        guard isCameraActive else {
            #if DEBUG
            print("Camera is not active, skipping image search")
@@ -240,15 +314,18 @@
         }
         DispatchQueue.main.async {
             self.assets = assets
+            self.updateGalleryStatus()
 //            self.processAndCachePhotos()
-            profileAsync("processAndCachePhotos") { done in
-                self.processAndCachePhotos {
-                    done()
-                }
-            } completion: { time in
-                #if DEBUG
-                print("Process and cache completed in \(time) ms")
-                #endif
-            }
+            if !self.isGalleryEmpty {
+                profileAsync("processAndCachePhotos") { done in
+                    self.processAndCachePhotos {
+                        done()
+                    }
+                } completion: { time in
+                    #if DEBUG
+                    print("Process and cache completed in \(time) ms")
+                    #endif
+                }
+            }
         }
     }
@@ -375,6 +452,7 @@
         }
     }
 
+    // Post-processing in MPSGraph: compute similarities and select the top photo IDs
     private func calculateAndPrintTopPhotoIDs(textFeatures: MLMultiArray) -> [String] {
         guard let device = MTLCreateSystemDefaultDevice() else {
             fatalError("Metal is not supported on this device")
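calculateAndPrintTopPhotoIDs ranks the cached photo vectors against the query features on the GPU through MPSGraph; the hunk is truncated before that code. For intuition only, a plain-Swift equivalent of 'score every cached vector, return the best IDs'; the function name, dictionary layout, and k are illustrative:

    import Foundation

    // Cosine-similarity top-k over cached feature vectors (CPU sketch, not the MPSGraph code).
    func topPhotoIDs(query: [Float], cache: [String: [Float]], k: Int = 24) -> [String] {
        func cosine(_ a: [Float], _ b: [Float]) -> Float {
            let dot = zip(a, b).map(*).reduce(0, +)
            let na = (a.map { $0 * $0 }.reduce(0, +)).squareRoot()
            let nb = (b.map { $0 * $0 }.reduce(0, +)).squareRoot()
            return dot / (na * nb + 1e-8)
        }
        return cache
            .map { (id: $0.key, score: cosine(query, $0.value)) }
            .sorted { $0.score > $1.score }
            .prefix(k)
            .map(\.id)
    }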
(The diff for the sixth changed file did not load.)