From cf7f37274d3104d11ccc84bd32a6ab433fa4c999 Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Fri, 20 Dec 2024 17:39:09 +0100
Subject: [PATCH] fix: issue with isModelGenerating when switching between multiple models (#73)

## Description

There was a problem when the user switched between multiple LLMs in one component. To fix it, I removed the code that worked around React strict mode (strict mode was also causing problems with the event listeners, so I think we should ignore it, as it doesn't play well with background tasks). I also removed the `deleteModule` native function, as it wasn't really doing anything. Users can now download multiple LLMs within one component and seamlessly switch between them without bugs. A minimal switching sketch is included under Testing instructions below.

The problem was mentioned in issue #42.

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions
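
The sketch below is one way to exercise the fix by hand: load one model, then switch to another within the same component. The model/tokenizer URLs are placeholders, the package import assumes the public `useLLM` export, and the fields read from the hook (`isReady`, `downloadProgress`, `isModelGenerating`) are assumed to match what `src/LLM.ts` exposes.

```tsx
// Hypothetical test component: switches between two LLMs in one place.
// URLs are placeholders; the useLLM prop/return shape is assumed from src/LLM.ts.
import React, { useState } from 'react';
import { Button, Text, View } from 'react-native';
import { useLLM } from 'react-native-executorch';

const MODELS = {
  modelA: {
    modelSource: 'https://example.com/model-a.pte', // placeholder
    tokenizerSource: 'https://example.com/tokenizer-a.bin', // placeholder
  },
  modelB: {
    modelSource: 'https://example.com/model-b.pte', // placeholder
    tokenizerSource: 'https://example.com/tokenizer-b.bin', // placeholder
  },
} as const;

export default function ModelSwitcher() {
  const [selected, setSelected] = useState<keyof typeof MODELS>('modelA');

  // Changing modelSource/tokenizerSource re-runs the hook's load effect,
  // which is what the removed `initialized` guard used to block.
  const llm = useLLM({
    modelSource: MODELS[selected].modelSource,
    tokenizerSource: MODELS[selected].tokenizerSource,
  });

  return (
    <View>
      <Text>
        {llm.isReady
          ? `${selected} ready`
          : `Loading ${selected} (progress: ${llm.downloadProgress})`}
      </Text>
      <Button
        title="Switch model"
        disabled={llm.isModelGenerating}
        onPress={() =>
          setSelected((prev) => (prev === 'modelA' ? 'modelB' : 'modelA'))
        }
      />
    </View>
  );
}
```

After switching, `isReady` should drop back to `false`, download progress should be reported for the new model, and `isModelGenerating` should stay consistent, with no stale state left over from the previous model.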
### Screenshots

### Related issues

### Checklist

- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes

---
 .../java/com/swmansion/rnexecutorch/LLM.kt |  5 ++--
 ios/RnExecutorch/LLM.mm                    | 24 +++++++++----------
 ios/RnExecutorch/utils/LargeFileFetcher.mm |  4 ++--
 src/LLM.ts                                 | 10 +++-----
 src/native/NativeLLM.ts                    |  2 +-
 5 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
index e12f027..393468b 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
@@ -64,6 +64,7 @@ class LLM(reactContext: ReactApplicationContext) :
   private fun initializeLlamaModule(modelPath: String, tokenizerPath: String, promise: Promise) {
     llamaModule = LlamaModule(1, modelPath, tokenizerPath, 0.7f)
     isFetching = false
+    this.tempLlamaResponse.clear()
     promise.resolve("Model loaded successfully")
   }
 
@@ -74,8 +75,8 @@ class LLM(reactContext: ReactApplicationContext) :
     contextWindowLength: Double,
     promise: Promise
   ) {
-    if (llamaModule != null || isFetching) {
-      promise.reject("Model already loaded", "Model is already loaded or fetching")
+    if (isFetching) {
+      promise.reject("Model is fetching", "Model is fetching")
       return
     }
 
diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm
index 2f270c6..8b6957f 100644
--- a/ios/RnExecutorch/LLM.mm
+++ b/ios/RnExecutorch/LLM.mm
@@ -28,7 +28,7 @@ - (instancetype)init {
     isFetching = NO;
     tempLlamaResponse = [[NSMutableString alloc] init];
   }
-  
+
   return self;
 }
 
@@ -38,7 +38,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
   if ([token isEqualToString:prompt]) {
     return;
   }
-  
+
   dispatch_async(dispatch_get_main_queue(), ^{
     [self emitOnToken:token];
     [self->tempLlamaResponse appendString:token];
@@ -54,9 +54,9 @@ - (void)updateDownloadProgress:(NSNumber *)progress {
 - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   NSURL *modelURL = [NSURL URLWithString:modelSource];
   NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
-  
-  if(self->runner || isFetching){
-    reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
+
+  if(isFetching){
+    reject(@"model_is_fetching", @"Model is fetching", nil);
     return;
   }
 
@@ -78,10 +78,11 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
 
   modelFetcher.onFinish = ^(NSString *modelFilePath) {
     self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
-    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength); 
+    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
     self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
     self->isFetching = NO;
+    self->tempLlamaResponse = [NSMutableString string];
     resolve(@"Model and tokenizer loaded successfully");
     return;
   };
 
@@ -94,23 +95,23 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
 - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   [conversationManager addResponse:input senderRole:ChatRole::USER];
   NSString *prompt = [conversationManager getConversation];
-  
+
   dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
     NSError *error = nil;
     [self->runner generate:prompt withTokenCallback:^(NSString *token) {
-      [self onResult:token prompt:prompt];
+      [self onResult:token prompt:prompt];
     } error:&error];
-    
+
     // make sure to add eot token once generation is done
     if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
       [self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
     }
-    
+
     if (self->tempLlamaResponse) {
       [self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
       self->tempLlamaResponse = [NSMutableString string];
     }
-    
+
     if (error) {
       reject(@"error_in_generation", error.localizedDescription, nil);
       return;
@@ -120,7 +121,6 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve
   });
 }
 
-
 -(void)interrupt {
   [self->runner stop];
 }
diff --git a/ios/RnExecutorch/utils/LargeFileFetcher.mm b/ios/RnExecutorch/utils/LargeFileFetcher.mm
index 6ae58db..48cc39b 100644
--- a/ios/RnExecutorch/utils/LargeFileFetcher.mm
+++ b/ios/RnExecutorch/utils/LargeFileFetcher.mm
@@ -12,7 +12,7 @@ @implementation LargeFileFetcher {
 - (instancetype)init {
   self = [super init];
   if (self) {
-    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:@"com.swmansion.rnexecutorch"];
+    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:[NSString stringWithFormat:@"com.swmansion.rnexecutorch.%@", [[NSUUID UUID] UUIDString]]];
     _session = [NSURLSession sessionWithConfiguration:configuration delegate:self delegateQueue:nil];
   }
   return self;
@@ -111,7 +111,7 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
 - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTask *)downloadTask didFinishDownloadingToURL:(NSURL *)location {
   NSFileManager *fileManager = [NSFileManager defaultManager];
-  
+
   [fileManager removeItemAtPath:_destination error:nil];
 
   NSError *error;
diff --git a/src/LLM.ts b/src/LLM.ts
index 3cd67ff..4219fdc 100644
--- a/src/LLM.ts
+++ b/src/LLM.ts
@@ -30,15 +30,8 @@ export const useLLM = ({
   const [downloadProgress, setDownloadProgress] = useState(0);
   const downloadProgressListener = useRef(null);
   const tokenGeneratedListener = useRef(null);
-  const initialized = useRef(false);
 
   useEffect(() => {
-    if (initialized.current) {
-      return;
-    }
-
-    initialized.current = true;
-
     const loadModel = async () => {
       try {
         let modelUrl = modelSource;
@@ -57,6 +50,7 @@ export const useLLM = ({
           }
         }
       );
+      setIsReady(false);
 
       await LLM.loadLLM(
         modelUrl as string,
@@ -83,6 +77,8 @@ export const useLLM = ({
       const message = (err as Error).message;
       setIsReady(false);
       setError(message);
+    } finally {
+      setDownloadProgress(0);
     }
   };
 
diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts
index 23ee518..35d2a42 100644
--- a/src/native/NativeLLM.ts
+++ b/src/native/NativeLLM.ts
@@ -10,8 +10,8 @@ export interface Spec extends TurboModule {
     contextWindowLength: number
   ): Promise;
   runInference(input: string): Promise;
-  deleteModule(): void;
   interrupt(): void;
+  deleteModule(): void;
   readonly onToken: EventEmitter;
   readonly onDownloadProgress: EventEmitter;