Skip to content

Commit

Permalink
Add new execute_binary command for auto; some UI bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dnakov authored and trufae committed Nov 7, 2024
1 parent ce5b851 commit 59f891f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
16 changes: 14 additions & 2 deletions r2ai/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import litellm
from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from .tools import r2cmd, run_python
from .pipe import get_filename
from .tools import r2cmd, run_python, execute_binary
import json
import signal
from .spinner import spinner
Expand Down Expand Up @@ -259,11 +260,22 @@ def cb(type, data):
if 'content' in data:
sys.stdout.write(data['content'])
elif type == 'tool_call':
builtins.print()
if data['function']['name'] == 'r2cmd':
builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m')
elif data['function']['name'] == 'run_python':
builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
builtins.print(data['function']['arguments']['command'])
elif data['function']['name'] == 'execute_binary':
filename = get_filename()
stdin = data['function']['arguments']['stdin']
args = data['function']['arguments']['args']
cmd = filename
if len(args) > 0:
cmd += ' ' + ' '.join(args)
if stdin:
cmd += f' stdin={stdin}'
builtins.print('\x1b[1;32m> \x1b[4m' + cmd + '\x1b[0m')
elif type == 'tool_response':
if 'content' in data:
sys.stdout.write(data['content'])
Expand All @@ -277,7 +289,7 @@ def signal_handler(signum, frame):

def chat(interpreter, **kwargs):
model = interpreter.model.replace(":", "/")
tools = [r2cmd, run_python]
tools = [r2cmd, run_python, execute_binary]
messages = interpreter.messages
tool_choice = 'auto'

Expand Down
27 changes: 27 additions & 0 deletions r2ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def r2cmd(command: str):
return log_messages

return res['res']
except json.JSONDecodeError:
return res
except Exception as e:
# return { 'type': 'error', 'output': f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" }
return f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}"
Expand All @@ -53,3 +55,28 @@ def run_python(command: str):
res = r2.cmd('#!python r2ai_tmp.py')
r2.cmd('rm r2ai_tmp.py')
return res

def execute_binary(args: list[str] = [], stdin: str = ""):
"""
Execute a binary with the given arguments and stdin
Parameters
----------
args: list[str]
The arguments to pass to the binary
stdin: str
The stdin to pass to the binary
Returns
-------
str
The output of the binary
"""

r2 = get_r2_inst()
if len(args) > 0:
r2.cmd(f"dor {' '.join(args)}")
if stdin:
r2.cmd(f'dor stdin="{stdin}"')
r2.cmd("ood")
return r2cmd("dc")
14 changes: 12 additions & 2 deletions r2ai/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,17 @@ def on_message(self, type: str, message: any) -> None:
except NoMatches:
existing = self.add_message(message["id"], "AI", message["content"])
elif type == 'tool_call':
self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}")
if 'command' in message['function']['arguments']:
self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}")
elif message['function']['name'] == 'execute_binary':
args = message['function']['arguments']
output = get_filename()
if 'args' in args and len(args['args']) > 0:
output += f" {args['args'].join(' ')}\n"
if 'stdin' in args and len(args['stdin']) > 0:
output += f" stdin={args['stdin']}\n"

self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {output}")
elif type == 'tool_response':
self.add_message(message["id"], "Tool Response", message['content'])

Expand All @@ -218,7 +228,7 @@ async def validate_model(self) -> None:
if not model:
await self.select_model()
if is_litellm_model(model):
model = self.ai.model
model = self.ai.model.replace(':', '/')
keys = validate_environment(model)
if keys['keys_in_environment'] is False:
await self.push_screen_wait(ModelConfigDialog(keys['missing_keys']))
Expand Down
4 changes: 2 additions & 2 deletions r2ai/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import signal
from r2ai.pipe import get_r2_inst
from r2ai.tools import run_python, r2cmd
from r2ai.tools import run_python, r2cmd, execute_binary
from r2ai.repl import r2ai_singleton
from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO
from r2ai.interpreter import is_litellm_model
Expand All @@ -14,7 +14,7 @@ def signal_handler(signum, frame):

async def chat(ai, message, cb):
model = ai.model.replace(":", "/")
tools = [r2cmd, run_python]
tools = [r2cmd, run_python, execute_binary]
messages = ai.messages + [{"role": "user", "content": message}]
tool_choice = 'auto'
if not is_litellm_model(model) and ai and not ai.llama_instance:
Expand Down

0 comments on commit 59f891f

Please sign in to comment.