fix: issue with isModelGenerating when switching between multiple models (#73)

## Description
There was a problem when the user switched between multiple LLMs in one
component. To fix it, I removed the code that worked around React Strict
Mode (Strict Mode was also causing problems with event listeners, so I
think we should ignore it, as it doesn't play well with background tasks).
I also removed the `deleteModule` native function, as it wasn't really
doing anything. The user can now download multiple LLMs within one
component and seamlessly switch between them without bugs.
The problem was mentioned in issue #42
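
A minimal sketch of what this enables, assuming `useLLM` is imported from `react-native-executorch` and its load effect re-runs when `modelSource` changes; the component, model URLs, and the `tokenizerSource` option name are illustrative, not part of this PR:

```tsx
import { useState } from 'react';
import { useLLM } from 'react-native-executorch'; // assumed export path

// Illustrative sources; any two valid model/tokenizer URLs would do.
const LLAMA_URL = 'https://example.com/llama-3.2-1B.pte';
const QWEN_URL = 'https://example.com/qwen-2.5-0.5B.pte';
const TOKENIZER_URL = 'https://example.com/tokenizer.bin';

export function ChatScreen() {
  const [modelSource, setModelSource] = useState(LLAMA_URL);

  // With the strict-mode `initialized` guard gone, re-rendering with a new
  // modelSource lets the hook load the second model instead of silently
  // keeping the first one.
  const llm = useLLM({ modelSource, tokenizerSource: TOKENIZER_URL });

  // A button handler (wired up in the real UI) that swaps models:
  const switchToQwen = () => setModelSource(QWEN_URL);

  return null; // render the chat UI here using `llm` and `switchToQwen`
}
```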

### Type of change
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [x] iOS
- [x] Android

### Testing instructions
1. In a component using `useLLM`, load one model and wait for it to become ready.
2. Change `modelSource` (and the tokenizer, if needed) to a different model.
3. Verify the new model downloads, loads, and generates tokens; repeat the switch a few times on both iOS and Android.

### Screenshots
<!-- Add screenshots here, if applicable -->

### Related issues
- #42

### Checklist
- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes
<!-- Include any additional information, assumptions, or context that
reviewers might need to understand this PR. -->
NorbertKlockiewicz authored Dec 20, 2024
1 parent e120d2b commit cf7f372
Showing 5 changed files with 21 additions and 24 deletions.
5 changes: 3 additions & 2 deletions android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
@@ -64,6 +64,7 @@ class LLM(reactContext: ReactApplicationContext) :
private fun initializeLlamaModule(modelPath: String, tokenizerPath: String, promise: Promise) {
llamaModule = LlamaModule(1, modelPath, tokenizerPath, 0.7f)
isFetching = false
this.tempLlamaResponse.clear()
promise.resolve("Model loaded successfully")
}

@@ -74,8 +75,8 @@ class LLM(reactContext: ReactApplicationContext) :
contextWindowLength: Double,
promise: Promise
) {
if (llamaModule != null || isFetching) {
promise.reject("Model already loaded", "Model is already loaded or fetching")
if (isFetching) {
promise.reject("Model is fetching", "Model is fetching")
return
}

24 changes: 12 additions & 12 deletions ios/RnExecutorch/LLM.mm
@@ -28,7 +28,7 @@ - (instancetype)init {
isFetching = NO;
tempLlamaResponse = [[NSMutableString alloc] init];
}

return self;
}

@@ -38,7 +38,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
if ([token isEqualToString:prompt]) {
return;
}

dispatch_async(dispatch_get_main_queue(), ^{
[self emitOnToken:token];
[self->tempLlamaResponse appendString:token];
@@ -54,9 +54,9 @@ - (void)updateDownloadProgress:(NSNumber *)progress {
- (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
NSURL *modelURL = [NSURL URLWithString:modelSource];
NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
if(self->runner || isFetching){
reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);

if(isFetching){
reject(@"model_is_fetching", @"Model is fetching", nil);
return;
}

@@ -78,10 +78,11 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou

modelFetcher.onFinish = ^(NSString *modelFilePath) {
self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);

self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
self->isFetching = NO;
self->tempLlamaResponse = [NSMutableString string];
resolve(@"Model and tokenizer loaded successfully");
return;
};
@@ -94,23 +95,23 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
- (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
[conversationManager addResponse:input senderRole:ChatRole::USER];
NSString *prompt = [conversationManager getConversation];

dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
NSError *error = nil;
[self->runner generate:prompt withTokenCallback:^(NSString *token) {
[self onResult:token prompt:prompt];
[self onResult:token prompt:prompt];
} error:&error];

// make sure to add eot token once generation is done
if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
[self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
}

if (self->tempLlamaResponse) {
[self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
self->tempLlamaResponse = [NSMutableString string];
}

if (error) {
reject(@"error_in_generation", error.localizedDescription, nil);
return;
@@ -120,7 +121,6 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve
});
}


-(void)interrupt {
[self->runner stop];
}
4 changes: 2 additions & 2 deletions ios/RnExecutorch/utils/LargeFileFetcher.mm
@@ -12,7 +12,7 @@ @implementation LargeFileFetcher {
- (instancetype)init {
self = [super init];
if (self) {
NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:@"com.swmansion.rnexecutorch"];
NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:[NSString stringWithFormat:@"com.swmansion.rnexecutorch.%@", [[NSUUID UUID] UUIDString]]];
_session = [NSURLSession sessionWithConfiguration:configuration delegate:self delegateQueue:nil];
}
return self;
@@ -111,7 +111,7 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {

- (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTask *)downloadTask didFinishDownloadingToURL:(NSURL *)location {
NSFileManager *fileManager = [NSFileManager defaultManager];

[fileManager removeItemAtPath:_destination error:nil];

NSError *error;
10 changes: 3 additions & 7 deletions src/LLM.ts
@@ -30,15 +30,8 @@ export const useLLM = ({
const [downloadProgress, setDownloadProgress] = useState(0);
const downloadProgressListener = useRef<null | EventSubscription>(null);
const tokenGeneratedListener = useRef<null | EventSubscription>(null);
const initialized = useRef(false);

useEffect(() => {
if (initialized.current) {
return;
}

initialized.current = true;

const loadModel = async () => {
try {
let modelUrl = modelSource;
@@ -57,6 +50,7 @@
}
}
);
setIsReady(false);

await LLM.loadLLM(
modelUrl as string,
@@ -83,6 +77,8 @@
const message = (err as Error).message;
setIsReady(false);
setError(message);
} finally {
setDownloadProgress(0);
}
};

2 changes: 1 addition & 1 deletion src/native/NativeLLM.ts
@@ -10,8 +10,8 @@ export interface Spec extends TurboModule {
contextWindowLength: number
): Promise<string>;
runInference(input: string): Promise<string>;
deleteModule(): void;
interrupt(): void;
deleteModule(): void;

readonly onToken: EventEmitter<string>;
readonly onDownloadProgress: EventEmitter<number>;
