diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
index e12f027..393468b 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
@@ -64,6 +64,7 @@ class LLM(reactContext: ReactApplicationContext) :
   private fun initializeLlamaModule(modelPath: String, tokenizerPath: String, promise: Promise) {
     llamaModule = LlamaModule(1, modelPath, tokenizerPath, 0.7f)
     isFetching = false
+    this.tempLlamaResponse.clear()
     promise.resolve("Model loaded successfully")
   }
 
@@ -74,8 +75,8 @@ class LLM(reactContext: ReactApplicationContext) :
     contextWindowLength: Double,
     promise: Promise
   ) {
-    if (llamaModule != null || isFetching) {
-      promise.reject("Model already loaded", "Model is already loaded or fetching")
+    if (isFetching) {
+      promise.reject("Model is fetching", "Model is fetching")
       return
     }
 
diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm
index 2f270c6..8b6957f 100644
--- a/ios/RnExecutorch/LLM.mm
+++ b/ios/RnExecutorch/LLM.mm
@@ -28,7 +28,7 @@ - (instancetype)init {
     isFetching = NO;
     tempLlamaResponse = [[NSMutableString alloc] init];
   }
-  
+
   return self;
 }
 
@@ -38,7 +38,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
   if ([token isEqualToString:prompt]) {
     return;
   }
-  
+
   dispatch_async(dispatch_get_main_queue(), ^{
     [self emitOnToken:token];
     [self->tempLlamaResponse appendString:token];
@@ -54,9 +54,9 @@ - (void)updateDownloadProgress:(NSNumber *)progress {
 - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   NSURL *modelURL = [NSURL URLWithString:modelSource];
   NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
-  
-  if(self->runner || isFetching){
-    reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
+
+  if(isFetching){
+    reject(@"model_is_fetching", @"Model is fetching", nil);
     return;
   }
 
@@ -78,10 +78,11 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
 
   modelFetcher.onFinish = ^(NSString *modelFilePath) {
     self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
-    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);  
+    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
     self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
 
     self->isFetching = NO;
+    self->tempLlamaResponse = [NSMutableString string];
     resolve(@"Model and tokenizer loaded successfully");
     return;
   };
@@ -94,23 +95,23 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   [conversationManager addResponse:input senderRole:ChatRole::USER];
   NSString *prompt = [conversationManager getConversation];
-  
+
   dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
     NSError *error = nil;
     [self->runner generate:prompt withTokenCallback:^(NSString *token) {
-      [self onResult:token prompt:prompt];  
+      [self onResult:token prompt:prompt];
     } error:&error];
-    
+
     // make sure to add eot token once generation is done
     if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
       [self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
     }
-    
+
     if (self->tempLlamaResponse) {
       [self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
       self->tempLlamaResponse = [NSMutableString string];
     }
-    
+
     if (error) {
       reject(@"error_in_generation", error.localizedDescription, nil);
       return;
@@ -120,7 +121,6 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve
   });
 }
 
-
 -(void)interrupt {
   [self->runner stop];
 }
diff --git a/ios/RnExecutorch/utils/LargeFileFetcher.mm b/ios/RnExecutorch/utils/LargeFileFetcher.mm
index 6ae58db..48cc39b 100644
--- a/ios/RnExecutorch/utils/LargeFileFetcher.mm
+++ b/ios/RnExecutorch/utils/LargeFileFetcher.mm
@@ -12,7 +12,7 @@ @implementation LargeFileFetcher {
 - (instancetype)init {
   self = [super init];
   if (self) {
-    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:@"com.swmansion.rnexecutorch"];
+    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:[NSString stringWithFormat:@"com.swmansion.rnexecutorch.%@", [[NSUUID UUID] UUIDString]]];
     _session = [NSURLSession sessionWithConfiguration:configuration delegate:self delegateQueue:nil];
   }
   return self;
@@ -111,7 +111,7 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
 
 - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTask *)downloadTask didFinishDownloadingToURL:(NSURL *)location {
   NSFileManager *fileManager = [NSFileManager defaultManager];
-  
+
   [fileManager removeItemAtPath:_destination error:nil];
 
   NSError *error;
diff --git a/src/LLM.ts b/src/LLM.ts
index 3cd67ff..4219fdc 100644
--- a/src/LLM.ts
+++ b/src/LLM.ts
@@ -30,15 +30,8 @@ export const useLLM = ({
   const [downloadProgress, setDownloadProgress] = useState(0);
   const downloadProgressListener = useRef(null);
   const tokenGeneratedListener = useRef(null);
-  const initialized = useRef(false);
 
   useEffect(() => {
-    if (initialized.current) {
-      return;
-    }
-
-    initialized.current = true;
-
     const loadModel = async () => {
       try {
         let modelUrl = modelSource;
@@ -57,6 +50,7 @@
             }
           }
         );
+        setIsReady(false);
 
         await LLM.loadLLM(
           modelUrl as string,
@@ -83,6 +77,8 @@
         const message = (err as Error).message;
         setIsReady(false);
         setError(message);
+      } finally {
+        setDownloadProgress(0);
       }
     };
 
diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts
index 23ee518..35d2a42 100644
--- a/src/native/NativeLLM.ts
+++ b/src/native/NativeLLM.ts
@@ -10,8 +10,8 @@ export interface Spec extends TurboModule {
     contextWindowLength: number
   ): Promise<string>;
   runInference(input: string): Promise<string>;
-  deleteModule(): void;
   interrupt(): void;
+  deleteModule(): void;
   readonly onToken: EventEmitter<string>;
   readonly onDownloadProgress: EventEmitter<number>;