Skip to content

Commit

Permalink
Use lazy imports to speed up initial loading time
Browse files Browse the repository at this point in the history
  • Loading branch information
dnakov committed Nov 11, 2024
1 parent fa42783 commit 12a1296
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
12 changes: 7 additions & 5 deletions r2ai/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@
from .models import get_hf_llm, new_get_hf_llm, get_default_model
from .voice import tts
from .const import R2AI_HOMEDIR
from . import auto, LOGGER, logging
from . import LOGGER, logging
from .web import stop_http_server, server_running
from .progress import progress_bar
import litellm
from .completion import messages_to_prompt

file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)
Expand Down Expand Up @@ -90,6 +88,7 @@ def ddg(m):
return f"Considering:\n```{res}\n```\n"

def is_litellm_model(model):
from litellm import models_by_provider
provider = None
model_name = None
if model.startswith ("/"):
Expand All @@ -98,7 +97,7 @@ def is_litellm_model(model):
provider, model_name = model.split(":")
elif "/" in model:
provider, model_name = model.split("/")
if provider in litellm.models_by_provider and model_name in litellm.models_by_provider[provider]:
if provider in models_by_provider and model_name in models_by_provider[provider]:
return True
return False

Expand Down Expand Up @@ -378,6 +377,7 @@ def respond(self):
# builtins.print(prompt)
response = None
if self.auto_run:
from . import auto
if(is_litellm_model(self.model)):
response = auto.chat(self)
else:
Expand Down Expand Up @@ -416,7 +416,8 @@ def respond(self):
# {"role": "system", "content": "You are a poetic assistant, be creative."},
# {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."}
# ]
completion = litellm.completion(
from litellm import completion as litellm_completion
completion = litellm_completion(
model=self.model.replace(":", "/"),
messages=self.messages,
max_completion_tokens=maxtokens,
Expand Down Expand Up @@ -452,6 +453,7 @@ def respond(self):
"max_tokens": maxtokens
}
if self.env["chat.rawdog"] == "true":
from .completion import messages_to_prompt
prompt = messages_to_prompt(self, messages)
response = self.llama_instance(prompt, **chat_args)
else:
Expand Down
10 changes: 4 additions & 6 deletions r2ai/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
from .utils import slurp, dump
from huggingface_hub import hf_hub_download, login
from huggingface_hub import HfApi, list_repo_tree, get_paths_info
from typing import Dict, List, Union
import appdirs
import builtins
import inquirer
import json
import llama_cpp
import os
import shutil
import subprocess
import sys
import traceback
from transformers import AutoTokenizer
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

# DEFAULT_MODEL = "TheBloke/CodeLlama-34B-Instruct-GGUF"
# DEFAULT_MODEL = "TheBloke/llama2-7b-chat-codeCherryPop-qLoRA-GGUF"
# DEFAULT_MODEL = "-m TheBloke/dolphin-2_6-phi-2-GGUF"
Expand Down Expand Up @@ -263,6 +257,7 @@ def get_hf_llm(ai, repo_id, debug_mode, context_window):

# Check if model was originally split
split_files = [model["filename"] for model in raw_models if selected_model in model["filename"]]
from huggingface_hub import hf_hub_download
if len(split_files) > 1:
# Download splits
for split_file in split_files:
Expand Down Expand Up @@ -466,6 +461,7 @@ def list_gguf_files(repo_id: str) -> List[Dict[str, Union[str, float]]]:
"""

try:
from huggingface_hub import HfApi
api = HfApi()
tree = list(api.list_repo_tree(repo_id))
files_info = [file for file in tree if file.path.endswith('.gguf')]
Expand Down Expand Up @@ -634,7 +630,9 @@ def supports_metal():


def get_llama_inst(repo_id, **kwargs):
import llama_cpp
if 'functionary' in repo_id:
from llama_cpp.llama_tokenizer import LlamaHFTokenizer
kwargs['tokenizer'] = LlamaHFTokenizer.from_pretrained(repo_id)
filename = os.path.basename(kwargs.pop('model_path'))
kwargs['echo'] = True
Expand Down
3 changes: 1 addition & 2 deletions r2ai/pipe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import traceback
import r2pipe
from .progress import progress_bar


have_rlang = False
r2lang = None
Expand Down Expand Up @@ -63,7 +63,6 @@ def r2singleton():
def get_r2_inst():
return r2singleton()

@progress_bar("Loading", color="yellow")
def open_r2(file, flags=[]):
global r2, filename, r2lang
r2 = r2pipe.open(file, flags=flags)
Expand Down

0 comments on commit 12a1296

Please sign in to comment.