diff --git a/r2ai/auto.py b/r2ai/auto.py index 639b528..05ca2af 100644 --- a/r2ai/auto.py +++ b/r2ai/auto.py @@ -7,7 +7,7 @@ from litellm import _should_retry, acompletion, utils, ModelResponse import asyncio from .pipe import get_filename -from .tools import r2cmd, run_python, execute_binary +from .tools import r2cmd, run_python, execute_binary, schemas, print_tool_call import json import signal from .spinner import spinner @@ -40,7 +40,7 @@ """ class ChatAuto: - def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ): + def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=60, stream=True, cb=None ): self.logger = LOGGER self.functions = {} self.tools = [] @@ -63,9 +63,13 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp self.tool_choice = None if tools: for tool in tools: - f = utils.function_to_dict(tool) - self.tools.append({ "type": "function", "function": f }) - self.functions[f['name']] = tool + if tool.__name__ in schemas: + schema = schemas[tool.__name__] + else: + schema = utils.function_to_dict(tool) + + self.tools.append({ "type": "function", "function": schema }) + self.functions[tool.__name__] = tool self.tool_choice = tool_choice self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.' 
@@ -143,7 +147,9 @@ async def process_streaming_response(self, resp): self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True }) self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done }) self.messages.append(current_message) - if len(current_message['tool_calls']) > 0: + if len(current_message['tool_calls']) == 0: + del current_message['tool_calls'] + else: await self.process_tool_calls(current_message['tool_calls']) return current_message @@ -247,8 +253,8 @@ async def get_completion(self): async def achat(self, messages=None) -> str: if messages: self.messages = messages + self.logger.debug(self.messages) response = await self.get_completion() - self.logger.debug(f'chat complete') return response def chat(self, **kwargs) -> str: @@ -261,25 +267,12 @@ def cb(type, data): sys.stdout.write(data['content']) elif type == 'tool_call': builtins.print() - if data['function']['name'] == 'r2cmd': - builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m') - elif data['function']['name'] == 'run_python': - builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m') - builtins.print(data['function']['arguments']['command']) - elif data['function']['name'] == 'execute_binary': - filename = get_filename() - stdin = data['function']['arguments']['stdin'] - args = data['function']['arguments']['args'] - cmd = filename - if len(args) > 0: - cmd += ' ' + ' '.join(args) - if stdin: - cmd += f' stdin={stdin}' - builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m') + print_tool_call(data) elif type == 'tool_response': if 'content' in data: sys.stdout.write(data['content']) sys.stdout.flush() + builtins.print() # builtins.print(data['content']) elif type == 'message' and data['done']: builtins.print() @@ -324,11 +317,4 @@ def chat(interpreter, **kwargs): finally: signal.signal(signal.SIGINT, original_handler) spinner.stop() - try: - pending = asyncio.all_tasks(loop=loop) - for task 
in pending: - task.cancel() - loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - loop.close() \ No newline at end of file + litellm.in_memory_llm_clients_cache.clear() \ No newline at end of file diff --git a/r2ai/main.py b/r2ai/main.py index 806dd6c..460018c 100755 --- a/r2ai/main.py +++ b/r2ai/main.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" import sys import builtins import traceback diff --git a/r2ai/repl.py b/r2ai/repl.py index ebab676..af3a772 100644 --- a/r2ai/repl.py +++ b/r2ai/repl.py @@ -14,7 +14,7 @@ from .interpreter import Interpreter from .pipe import have_rlang, r2lang, r2singleton from r2ai import bubble, LOGGER - +from .test import run_test tab_init() print_buffer = "" @@ -217,6 +217,10 @@ def runline(ai, usertext): print(help_message) elif usertext.startswith("clear") or usertext.startswith("-k"): print("\x1b[2J\x1b[0;0H\r") + if ai.messages: + ai.messages = [] + if autoai and autoai.messages: + autoai.messages = [] elif usertext.startswith("-MM"): print(models().strip()) elif usertext.startswith("-M"): @@ -469,6 +473,8 @@ def runline(ai, usertext): print("r2 is not available", file=sys.stderr) else: builtins.print(r2_cmd(usertext[1:])) + elif usertext.startswith("--test"): + run_test(usertext[7:]) elif usertext.startswith("-"): print(f"Unknown flag '{usertext}'. 
py_code = """
print('hello test')
"""


def _expect(condition, message):
    """Raise AssertionError(message) unless condition holds.

    Used instead of a bare ``assert`` so the self-test checks still run
    when Python is started with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def run_test(args):
    """Run the tool self-test, or invoke a single tool helper by name.

    Parameters
    ----------
    args: str
        Text following ``--test`` on the REPL line. When empty, the
        built-in self-test runs. Otherwise the first word selects a
        helper (``get_filename``, ``run_python`` or ``r2cmd``) and the
        remainder of the line is handed to it as a single string.

    Raises
    ------
    AssertionError
        If one of the self-test checks fails.
    """
    line = (args or "").strip()

    if not line:
        res = run_python(py_code).strip()
        print(f"run_python: {res}", len(res))
        _expect(res == "hello test", f"run_python returned {res!r}")
        print("run_python: test passed")

        # NOTE(review): presumably re-opens /bin/ls as the current file so
        # that execute_binary() below targets it - confirm against r2 docs.
        r2cmd("o--;o /bin/ls")
        res = execute_binary(args=["-d", "/etc"]).strip()
        subp = subprocess.run(["/bin/ls", "-d", "/etc"], capture_output=True, text=True)
        print("exec result", res)
        print("subp result", subp.stdout)
        _expect(res == subp.stdout.strip(),
                "execute_binary output differs from /bin/ls -d /etc")
        print("execute_binary with args: test passed")
        return

    # Keep the payload as ONE string: r2cmd() and run_python() both expect
    # a str, not the list that star-unpacking a split() result produces.
    cmd, _, payload = line.partition(" ")

    if cmd == "get_filename":
        builtins.print(get_filename())
    elif cmd == "run_python":
        builtins.print("--- args ---")
        builtins.print(payload)
        builtins.print("--- end args ---")
        builtins.print("--- result ---")
        builtins.print(run_python(payload))
        builtins.print("--- end result ---")
    elif cmd == "r2cmd":
        builtins.print(f"--- {payload} ---")
        builtins.print(r2cmd(payload))
        builtins.print("--- end ---")
    else:
        builtins.print(f"Unknown test command '{cmd}'")
def run_python(command: str):
    """
    Run a python script and return its output

    Parameters
    ----------
    command: str
        The python script to run

    Returns
    -------
    str
        The output of the python script
    """
    # NOTE: this docstring is turned into the tool schema sent to the model
    # (tools without an entry in `schemas` go through function_to_dict), so
    # keep it in numpydoc form.
    try:
        # Run the snippet in a separate interpreter process (the same
        # executable that runs r2ai) and capture both output streams.
        proc = subprocess.run([sys.executable, '-c', command],
                              capture_output=True,
                              text=True)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, so
        # the caller always gets a string back.
        return str(e)
    # stdout first, then stderr (tracebacks end up after normal output).
    return proc.stdout + proc.stderr
def _stringify_args(args):
    """Coerce a tool-call ``args`` value into a list of strings.

    ``None`` becomes an empty list, a lone string is treated as a single
    argument, and every item of any other iterable is passed through
    ``str()``.
    """
    if args is None:
        return []
    if isinstance(args, str):
        return [args]
    return [str(arg) for arg in args]


def execute_binary(args: list[str] = None, stdin: str = ""):
    """
    Execute a binary with the given arguments and stdin

    Parameters
    ----------
    args: list[str]
        The arguments to pass to the binary. Do not include the file name.
    stdin: str
        The stdin to pass to the binary

    Returns
    -------
    str
        The stdout of the binary followed by its stderr, the error text if
        the binary could not be started, or an empty string when
        get_filename() returns nothing.
    """
    filename = get_filename()
    if not filename:
        return ""

    # Prefer a path that exists on disk: the normalized absolute path, or
    # the name resolved against the current working directory.
    if os.path.isabs(filename):
        candidate = os.path.abspath(filename)
    else:
        candidate = os.path.join(os.getcwd(), filename)
    if os.path.exists(candidate):
        filename = candidate

    try:
        cmd = [filename] + _stringify_args(args)
        proc = subprocess.run(cmd, input=stdin, capture_output=True, text=True)
    except Exception as e:
        # Tool boundary: report the failure as text rather than raising.
        return str(e)
    return proc.stdout + proc.stderr


def print_tool_call(msg):
    """Print a highlighted one-line summary of a tool call.

    Parameters
    ----------
    msg: dict
        Tool-call message; reads ``msg['function']['name']`` and the
        ``msg['function']['arguments']`` mapping. Missing argument keys are
        tolerated so a malformed call cannot crash the display callback.
        Tool names other than r2cmd, run_python and execute_binary print
        nothing.
    """
    def highlight(text):
        # Bold green "> " prompt, underlined text, then reset attributes.
        return '\x1b[1;32m> \x1b[4m' + text + '\x1b[0m'

    function = msg['function']
    name = function['name']
    arguments = function.get('arguments') or {}

    if name == 'r2cmd':
        builtins.print(highlight(str(arguments.get('command', ''))))
    elif name == 'run_python':
        builtins.print(highlight("#!python"))
        builtins.print(arguments.get('command', ''))
    elif name == 'execute_binary':
        cmd = get_filename() or 'bin'
        args = _stringify_args(arguments.get('args'))
        stdin = arguments.get('stdin')
        if args:
            cmd += ' ' + ' '.join(args)
        if stdin:
            cmd += f' stdin={stdin}'
        builtins.print(highlight(cmd))
int(ai.env["llm.window"])) - chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb) - + chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=ai.messages, tool_choice=tool_choice, cb=cb) + return await chat_auto.achat()