Skip to content

Commit

Permalink
TUI updates and some error handling
Browse files Browse the repository at this point in the history
* ctrl+o to open binary, ctrl+m for model selector, error handling, some formatting updates
  • Loading branch information
dnakov authored Sep 25, 2024
1 parent 5acab2a commit e5445fd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 27 deletions.
72 changes: 48 additions & 24 deletions r2ai/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from textual.containers import ScrollableContainer, Container, Horizontal, VerticalScroll, Grid, Vertical # Add Vertical to imports
from textual.widgets import Header, Footer, Input, Button, Static, DirectoryTree, Label, Tree, Markdown
from textual.command import CommandPalette, Command, Provider, Hits, Hit
from textual.screen import Screen
from textual.screen import Screen, ModalScreen
from textual.message import Message
from textual.reactive import reactive
from .model_select import ModelSelect
Expand All @@ -14,6 +14,7 @@
from textual.widget import Widget
from textual.css.query import NoMatches
from textual import log

# from ..repl import set_model, r2ai_singleton
# ai = r2ai_singleton()
from .chat import chat, messages
Expand All @@ -25,7 +26,7 @@ async def search(self, query: str) -> Hits:
yield Hit("Select Model", "Select Model", self.action_select_model)


class ModelSelectDialog(Screen):
class ModelSelectDialog(ModalScreen):
def compose(self) -> ComposeResult:
yield Grid(ModelSelect(), id="model-select-dialog")

Expand All @@ -42,18 +43,20 @@ def __init__(self, id: str, sender: str, content: str, **kwargs) -> None:
# self.markdown = f"*{sender}*: {content}"
self.markdown = content
self.sender = sender

async def watch_markdown(self, markdown: str) -> None:
mkd = self.query_exactly_one(".text1")
print(self.markdown)
text = self.markdown
if hasattr(mkd, 'update') and callable(mkd.update):
update_method = mkd.update
if asyncio.iscoroutinefunction(update_method):
await update_method(text)
else:
update_method(text)

try:
mkd = self.query_exactly_one(".text1")
text = self.markdown
if hasattr(mkd, 'update') and callable(mkd.update):
update_method = mkd.update
if asyncio.iscoroutinefunction(update_method):
await update_method(text)
else:
update_method(text)
except NoMatches:
pass

def add_text(self, markdown: str) -> None:
self.markdown += markdown

Expand All @@ -80,8 +83,23 @@ class R2AIApp(App):
CSS_PATH = "app.tcss"
BINDINGS = [
("ctrl+p", "show_command_palette", "Command Palette"),
("ctrl+m", "select_model", "Select Model"),
("ctrl+o", "load_binary", "Load Binary")
]
TITLE = "r2ai"
SUB_TITLE = reactive(get_env('model'))

def update_sub_title(self, binary: str = None) -> str:
sub_title = None
model = get_env('model')
if binary and model:
binary = Path(binary).name
sub_title = f"{model} | {binary}"
elif binary:
sub_title = binary
else:
sub_title = model
self.sub_title = sub_title

def compose(self) -> ComposeResult:
yield Header()
Expand All @@ -108,14 +126,19 @@ def on_mount(self) -> None:

def action_show_command_palette(self) -> None:
self.push_screen("command_palette")

def action_select_model(self) -> None:
model = self.push_screen(ModelSelectDialog())

@work
async def action_select_model(self) -> None:
model = await self.push_screen_wait(ModelSelectDialog())
if model:
self.notify(f"Selected model: {get_env('model')}")

def action_load_binary(self) -> None:
self.push_screen(BinarySelectDialog())
self.update_sub_title()
@work
async def action_load_binary(self) -> None:
binary = await self.push_screen_wait(BinarySelectDialog())
if binary:
self.notify(f"Selected binary: {binary}")
self.update_sub_title(binary)

def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]:
yield from super().get_system_commands(screen)
Expand All @@ -125,8 +148,6 @@ def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]:
def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "send-button":
self.send_message()
def on_model_select_model_selected(self, event: ModelSelect.ModelSelected) -> None:
self.notify(f"Selected model: {event.model}")

async def on_input_submitted(self, event: Input.Submitted) -> None:
if event.input.id == "chat-input":
Expand All @@ -142,7 +163,7 @@ def on_message(self, type: str, message: any) -> None:
except NoMatches:
existing = self.add_message(message["id"], "AI", message["content"])
elif type == 'tool_call':
self.add_message(message["id"], "Tool Call", message['function']['name'])
self.add_message(message["id"], "Tool Call", f"{message['function']['name']} > {message['function']['arguments']['command']}")
elif type == 'tool_response':
self.add_message(message["id"], "Tool Response", message['content'])

Expand All @@ -152,7 +173,10 @@ async def send_message(self) -> None:
if message:
self.add_message(None, "User", message)
input_widget.value = ""
await chat(message, self.on_message)
try:
await chat(message, self.on_message)
except Exception as e:
self.notify(str(e), severity="error")

def add_message(self, id: str, sender: str, content: str) -> None:
chat_container = self.query_one("#chat-container", VerticalScroll)
Expand Down Expand Up @@ -181,7 +205,7 @@ def compose(self) -> ComposeResult:
yield Message(message)


class BinarySelectDialog(Screen):
class BinarySelectDialog(ModalScreen):
BINDINGS = [
("up", "cursor_up", "Move cursor up"),
("down", "cursor_down", "Move cursor down"),
Expand Down
7 changes: 4 additions & 3 deletions r2ai/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ async def process_tool_calls(tool_calls, cb):
for tool_call in tool_calls:
tool_name = tool_call["function"]["name"]
tool_args = json.loads(tool_call["function"]["arguments"])
if cb:
cb('tool_call', { "id": tool_call["id"], "function": { "name": tool_name, "arguments": tool_args } })
if tool_name == "r2cmd":
res = r2cmd(tool_args["command"])
messages.append({"role": "tool", "name": tool_name, "content": res + tool_end_message, "tool_call_id": tool_call["id"]})
Expand Down Expand Up @@ -121,9 +123,6 @@ async def process_streaming_response(resp, cb):
}
}
)
if cb:
cb('tool_call', tool_calls[index])
print(tool_calls)

# handle some bug in llama-cpp-python streaming, tool_call.arguments is sometimes blank, but function_call has it.
# if fn_delta.arguments == '':
Expand Down Expand Up @@ -160,5 +159,7 @@ async def get_completion(cb):

async def chat(message: str, cb) -> str:
messages.append({"role": "user", "content": message})
if not get_env("model"):
raise Exception("No model selected")
response = await get_completion(cb)
return response

0 comments on commit e5445fd

Please sign in to comment.