
Commit

Add litellm and llama-cpp (with openai style response now)
nitanmarcel committed Sep 14, 2024
1 parent 82d5eb9 commit c1f0e2e
Showing 16 changed files with 135 additions and 122 deletions.
4 changes: 2 additions & 2 deletions r2ai/interpreter.py
@@ -11,7 +11,7 @@
from rich.rule import Rule
from signal import signal, SIGINT

from .interprete_base import BaseInterpreter
from .interpreter_base import BaseInterpreter
from .large import Large
from .utils import merge_deltas
from .message_block import MessageBlock
@@ -732,7 +732,7 @@ def keywords_ai(self, text):
mm = None
return [word.strip() for word in text0.split(',')]

@progress_bar("Thinking", color="yellow")
#@progress_bar("Thinking", color="yellow")
def chat(self, message=None):
global print
global Ginterrupted
File renamed without changes.
4 changes: 2 additions & 2 deletions r2ai/progress.py
@@ -23,7 +23,7 @@ def wrapper(*args, **kwargs):
if server_running() and not server_in_background():
return func(*args, **kwargs)

with Progress(SpinnerColumn(), *Progress.get_default_columns(), console=Console(no_color=not bool(color)), transient=False) as p:
with Progress(SpinnerColumn(), *Progress.get_default_columns(), console=Console(no_color=not bool(color)), transient=True) as p:
task_text = f"[{color}]{text}" if color else text
task = p.add_task(
task_text, total=None if is_infinite else total)
@@ -58,7 +58,7 @@ def __enter__(self):
console=Console(
no_color=not bool(
self.color)),
transient=False)
transient=True)
if self.color:
self.task = self.progress.add_task(
f"[{self.color}]{self.text}", total=None if self.infinite else self.total)
2 changes: 1 addition & 1 deletion r2ai/r2clippy/__init__.py
@@ -1,5 +1,5 @@
from r2ai.r2clippy.chat import auto_chat
from r2ai.interprete_base import BaseInterpreter
from r2ai.interpreter_base import BaseInterpreter

Interpreter = None

Empty file removed r2ai/r2clippy/ais/__init__.py
Empty file.
33 changes: 0 additions & 33 deletions r2ai/r2clippy/ais/ai_openai.py

This file was deleted.

68 changes: 40 additions & 28 deletions r2ai/r2clippy/chat.py
@@ -1,35 +1,47 @@

from r2ai import LOGGER
from r2ai.r2clippy.ais.ai_openai import chat as openai_chat
from r2ai.models import new_get_hf_llm
from r2ai.r2clippy.functions import get_ai_tools
from r2ai.r2clippy.models import parse_model_str
from r2ai.r2clippy.processors import process_streaming_response
from r2ai.r2clippy.decorators import system_message, context
from litellm import completion as litellm_completion
from r2ai.interpreter_base import BaseInterpreter
from r2ai.r2clippy.constants import LITELMM_PROVIDERS

def auto_chat(interpreter):
model = parse_model_str(interpreter.model)
fn = None
if model.id in auto_chat_handlers.get(model.platform):
if model.uri:
interpreter.api_base = model.uri
fn = auto_chat_handlers[model.platform][model.id]
elif "default" in auto_chat_handlers.get(model.platform, {}):
if model.uri:
interpreter.api_base = model.uri
fn = auto_chat_handlers[model.platform]["default"]
if not fn:
LOGGER.error("Model %s:%s is not currently supported in auto mode")
return
return fn(interpreter)


def chat_open_ai():
pass
from r2ai.r2clippy.utils.split_string import split_string_with_limit
from llama_cpp import Llama

def auto_chat(interpreter: BaseInterpreter):
model = parse_model_str(interpreter.model)
_auto_chat(interpreter, model)

auto_chat_handlers = {
"openai": {
"default": openai_chat
},
"openapi": {
"default": openai_chat
}
}
@system_message
@context
def _auto_chat(interpreter: BaseInterpreter, model):
call = True
while call:
extra_args = {}
completion = None
if model.platform not in LITELMM_PROVIDERS:
if not interpreter.llama_instance:
interpreter.llama_instance = new_get_hf_llm(interpreter, f"{model.platform}/{model.id}", (LOGGER.level / 10) == 1, int(interpreter.env["llm.window"]))
completion = interpreter.llama_instance.create_chat_completion_openai_v1
extra_args = {}
else:
completion = litellm_completion
extra_args = {"num_retries": 3,
"base_url": model.uri}
response = completion(
model=f"{model.platform}/{model.id}",
max_tokens=4000, #int(interpreter.env["llm.maxtokens"]),
tools=get_ai_tools(),
messages=interpreter.messages,
tool_choice="auto",
stream=True,
temperature=float(interpreter.env["llm.temperature"]),
**extra_args
)
call = process_streaming_response(
interpreter, response)
return response
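
The reason one streaming processor can now serve both backends is what the commit title hints at: llama-cpp's create_chat_completion_openai_v1(..., stream=True) and litellm's completion(..., stream=True) both yield OpenAI-style chunks exposing choices[0].delta. The sketch below is not part of this commit; the model string and prompt are placeholders, and the litellm call needs provider credentials. It only illustrates the shared chunk shape that process_streaming_response relies on.

from litellm import completion as litellm_completion

def stream_text(response) -> str:
    """Collect text deltas from an OpenAI-style streamed response."""
    parts = []
    for chunk in response:
        delta = chunk.choices[0].delta  # same attribute on litellm and llama-cpp chunks
        if getattr(delta, "content", None):
            parts.append(delta.content)
    return "".join(parts)

# litellm path, taken when model.platform is listed in LITELMM_PROVIDERS:
response = litellm_completion(
    model="groq/llama3-8b-8192",  # placeholder "platform/id" model string
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
print(stream_text(response))
# The llama-cpp path (interpreter.llama_instance.create_chat_completion_openai_v1)
# streams chunks with the same structure, so the same loop applies.
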
27 changes: 27 additions & 0 deletions r2ai/r2clippy/chunks.py
@@ -0,0 +1,27 @@
from queue import Queue
from typing import Dict, Tuple, Union

from r2ai.r2clippy.utils import split_string_with_limit

_chunks: Queue = Queue()


def add_chunk(text: str, max_tokens: int = 2000):
global _chunks
if _chunks.qsize() > 0:
_chunks = Queue()
if text.strip() == "":
return 0
for i in split_string_with_limit(text, max_tokens, "cl100k_base"):
_chunks.put(i)
return _chunks.qsize()

def get_chunk():
global _chunks
if _chunks.qsize() == 0:
return str()
return _chunks.get()

def size():
global _chunks
return _chunks.qsize()
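
A quick usage sketch of the new chunk queue (not from this commit): add_chunk splits a long string into token-limited pieces, queuing a new text discards any leftovers, and get_chunk/size drain and inspect what remains for a later RetriveChunk call.

from r2ai.r2clippy.chunks import add_chunk, get_chunk, size

n = add_chunk("some very long radare2 output " * 1000, max_tokens=2000)
print(n)             # number of chunks queued (0 if the text was empty)
first = get_chunk()  # first piece, removed from the queue
print(size())        # chunks still waiting to be fetched
add_chunk("new output")  # queuing new text resets the queue first
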
6 changes: 6 additions & 0 deletions r2ai/r2clippy/constants.py
@@ -22,3 +22,9 @@
Think step by step.
Break down the task into steps and execute the necessary `radare2` commands in order to complete the task.
"""

LITELMM_PROVIDERS = [
'perplexity', 'friendliai', 'together_ai', 'groq', 'fireworks_ai', 'ai21_chat',
'deepinfra', 'anyscale', 'deepseek', 'codestral', 'mistral', 'nvidia_nim', 'ai21',
'empower', 'azure_ai', 'cerebras', 'volcengine', 'voyage', 'github'
]
3 changes: 1 addition & 2 deletions r2ai/r2clippy/decorators.py
@@ -1,12 +1,11 @@
from functools import wraps

from r2ai.interprete_base import BaseInterpreter
from r2ai.interpreter_base import BaseInterpreter
from r2ai.r2clippy.models import get_model_by_str
from r2ai.r2clippy.utils import context_from_msg

# TODO: context for each model


def context(func):
@wraps(func)
def wrapper(*args, **kwargs):
31 changes: 26 additions & 5 deletions r2ai/r2clippy/functions.py
@@ -5,6 +5,8 @@
from pydantic_core import ValidationError

from r2ai.pipe import get_r2_inst, r2lang
from r2ai.r2clippy.chunks import get_chunk, add_chunk, size
from r2ai import LOGGER

class _FunctionStorage:
def __init__(self):
@@ -14,8 +16,6 @@ def store(self):
def decorator(cls):
if cls not in self._storage:
self._storage.append(cls)
else:
print(self._storage)
return cls
return decorator

@@ -33,15 +33,20 @@ class R2Cmd(OpenAISchema):
the `#`, '#!', etc. commands. The output could be long, so try to use filters
if possible or limit. This is your preferred tool
"""
command: str = Field(default="!echo hi",
description="radare2 command to run")
command: str = Field(description="radare2 command to run")

@computed_field
def result(self) -> str:
r2 = get_r2_inst()
print("Running %s" % self.command)
res = r2.cmd(self.command)
print(res)
add_chunk(res)
res = get_chunk()
chunk_size = size()
if chunk_size > 0:
res += f"\nChunked message. Remaining chunks: {chunk_size}. Use RetriveChunk to retrieve the next chunk."
LOGGER.getChild("auto").info("Response has been chunked. Nr of chunks: %s", chunk_size)
return res


@@ -58,10 +63,27 @@ def result(self) -> str:
print(self.snippet)
r2lang.cmd('#!python r2ai_tmp.py > $tmp')
res = r2lang.cmd('cat $tmp')
add_chunk(res)
res = get_chunk()
chunk_size = size()
if chunk_size > 0:
res += f"\nChunked message. Remaining chunks: {chunk_size}. Use RetriveChunk to retrieve the next chunk."
r2lang.cmd('rm r2ai_tmp.py')
print(res)
return res

@FunctionStorage.store()
class RetriveChunk(OpenAISchema):
"""gets a chunk of a chunked message."""

@computed_field
def result(self) -> str:
res = get_chunk()
chunk_size = size()
if chunk_size > 0:
res += f"\nChunked message. Remaining chunks: {chunk_size}. Use RetriveChunk to retrieve the next chunk."
LOGGER.getChild("auto").info("Remaining chunks: %s", chunk_size)
return res

def get_ai_tools() -> Dict[str, str]:
tools = []
@@ -74,7 +96,6 @@ def get_ai_tools() -> Dict[str, str]:
)
return tools


def validate_ai_tool(arguments: Dict[str, str]) -> OpenAISchema:
tools = FunctionStorage.get_all()
original_exception = None
9 changes: 9 additions & 0 deletions r2ai/r2clippy/llamachatcompletion.py
@@ -0,0 +1,9 @@
# from openai.types.chat import ChatCompletion


# class LlamaChatCompletion(ChatCompletion):
# @classmethod
# def from_llama_response(llamaResponse: dict) -> ChatCompletion:
# choice = llamaResponse["choices"][0]
# if "message" in choice:
# if "content" in choice["message"]:
56 changes: 8 additions & 48 deletions r2ai/r2clippy/ais/processors.py → r2ai/r2clippy/processors.py
@@ -1,65 +1,23 @@
import json
import sys

from openai.types.chat import ChatCompletionChunk
from typing import Union
from openai.types.chat import ChatCompletion
from litellm.types.utils import ModelResponse
from pydantic_core import ValidationError

from r2ai import LOGGER
from r2ai.r2clippy.constants import ANSI_REGEX
from r2ai.r2clippy.functions import PythonCmd, R2Cmd, validate_ai_tool

_retries = 0
_chunks = [] # workaround for retrying


def process_streaming_response(interpreter, response, max_retries=0):
try:
return _process_streaming_response(interpreter, response)
except Exception as e:
global _retries
if _retries == max_retries:
raise e
if not _chunks:
raise e
_retries += 1
LOGGER.getChild("r2clippy").info("Got invalid response %s Retrying", e)
choice = _chunks[-1].choices[0]
if hasattr(choice, "delta"):
delta = choice.delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
delta_tool_calls = delta.tool_calls[0]
fn_delta = delta_tool_calls.function
tool_call_id = delta_tool_calls.id
if tool_call_id is not None:
interpreter.messages.append(
{
"role": "tool",
"content": f"Validation Error found:\n{e}\nRecall the function correctly, fix the errors".strip(),
"name": fn_delta.name,
"tool_call_id": tool_call_id
}
)
return True
interpreter.messages.append(
{
"role": "user",
"content": f"Validation Error found:\n{e}\nRecall the function correctly, fix the errors".strip()
}
)
return True


def _process_streaming_response(interpreter, response) -> bool:
def process_streaming_response(interpreter, response) -> bool:
"""Process streaming response.
Returns True if a chat call should be done
"""
tool_calls = []
msgs = []
chunk: ChatCompletionChunk
global _chunks
_chunks = []
chunk: Union[ModelResponse, ChatCompletion]
for chunk in response:
_chunks.append(chunk)
delta = None
choice = chunk.choices[0]
if hasattr(choice, "delta"):
@@ -77,7 +35,7 @@ def _process_streaming_response(interpreter, response) -> bool:
tool_calls.append({
"function": {
"arguments": "",
"name": fn_delta.name,
"name": fn_delta.name.split(".")[-1], # For some reason, sometimes the nameas are set as: function.FunctionName
},
"id": tool_call_id,
"type": "function"
@@ -128,6 +86,8 @@ def process_tool_calls(interpreter, tool_calls):
raise ValueError("Tool name must not be null")
if not tool_id:
raise ValueError("Tool id must not be null")

tool_name = tool_name.split(".")[-1] # For some reason, sometimes the names are set as: function.FunctionName

msg = {
"role": "tool",
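
A one-line illustration of the tool-name normalization added in both hunks above (standalone snippet, not from the commit): some providers return the tool name qualified, e.g. "functions.R2Cmd", so only the last dotted component is kept before dispatching.

for name in ("R2Cmd", "functions.R2Cmd", "function.PythonCmd"):
    print(name.split(".")[-1])  # -> R2Cmd, R2Cmd, PythonCmd
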
4 changes: 4 additions & 0 deletions r2ai/r2clippy/utils/__init__.py
@@ -0,0 +1,4 @@
from r2ai.r2clippy.utils.context import context_from_msg
from r2ai.r2clippy.utils.split_string import split_string_with_limit


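The new r2ai/r2clippy/utils/split_string.py itself is not shown in this diff; the sketch below is only a guess at a minimal token-limited splitter matching the call shape used in chunks.py, split_string_with_limit(text, limit, "cl100k_base"), assuming tiktoken is available.

import tiktoken

def split_string_with_limit(text: str, limit: int, encoding_name: str):
    """Split text into pieces of at most `limit` tokens (hypothetical sketch)."""
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    return [
        encoding.decode(tokens[i:i + limit])
        for i in range(0, len(tokens), limit)
    ]

chunks = split_string_with_limit("word " * 5000, 2000, "cl100k_base")
print(len(chunks))  # roughly 3 pieces for ~5000 tokens of input
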
2 changes: 1 addition & 1 deletion r2ai/r2clippy/utils.py → r2ai/r2clippy/utils/context.py
@@ -17,4 +17,4 @@ def context_from_msg(msg: dict):
False, False, False, False, use_vectordb)
if not matches:
return None
return "context: " + ", ".join(matches)
return "context: " + ", ".join(matches)
