Skip to content

Commit

Permalink
TUI: Ask for API KEY if one is not set in env
Browse files Browse the repository at this point in the history
  • Loading branch information
dnakov authored and trufae committed Sep 26, 2024
1 parent bc78ba6 commit dafd93d
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions r2ai/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from textual.widget import Widget
from textual.css.query import NoMatches
from textual import log
from litellm import validate_environment

# from ..repl import set_model, r2ai_singleton
# ai = r2ai_singleton()
Expand All @@ -33,7 +34,21 @@ def compose(self) -> ComposeResult:
def on_model_select_model_selected(self, event: ModelSelect.ModelSelected) -> None:
self.dismiss(event.model)

class ModelConfigDialog(ModalScreen):
def __init__(self, keys: list[str]) -> None:
super().__init__()
self.keys = keys

def compose(self) -> ComposeResult:
for key in self.keys:
yield Input(placeholder=key, id=f"{key}-input")
yield Button("Save", variant="primary", id="save-button")

def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "save-button":
for key in self.keys:
os.environ[key] = self.query_one(f"#{key}-input", Input).value
self.dismiss()
class ChatMessage(Widget):
markdown: reactive[str] = reactive("")
sender = "User"
Expand Down Expand Up @@ -127,12 +142,18 @@ def on_mount(self) -> None:
def action_show_command_palette(self) -> None:
self.push_screen("command_palette")

@work
async def action_select_model(self) -> None:

async def select_model(self) -> None:
model = await self.push_screen_wait(ModelSelectDialog())
if model:
await self.validate_model()
self.notify(f"Selected model: {get_env('model')}")
self.update_sub_title()
self.update_sub_title()

@work
async def action_select_model(self) -> None:
await self.select_model()

@work
async def action_load_binary(self) -> None:
binary = await self.push_screen_wait(BinarySelectDialog())
Expand All @@ -149,6 +170,7 @@ def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "send-button":
self.send_message()

@work
async def on_input_submitted(self, event: Input.Submitted) -> None:
if event.input.id == "chat-input":
await self.send_message()
Expand All @@ -174,10 +196,22 @@ async def send_message(self) -> None:
self.add_message(None, "User", message)
input_widget.value = ""
try:
await self.validate_model()
await chat(message, self.on_message)
except Exception as e:
self.notify(str(e), severity="error")

async def validate_model(self) -> None:
model = get_env("model")
if not model:
await self.select_model()
model = get_env("model")
keys = validate_environment(model)
if keys['keys_in_environment'] is False:
await self.push_screen_wait(ModelConfigDialog(keys['missing_keys']))

return True

def add_message(self, id: str, sender: str, content: str) -> None:
chat_container = self.query_one("#chat-container", VerticalScroll)
msg = ChatMessage(id, sender, content)
Expand All @@ -188,7 +222,6 @@ def add_message(self, id: str, sender: str, content: str) -> None:
def scroll_to_bottom(self) -> None:
chat_scroll = self.query_one("#chat-container", VerticalScroll)
chat_scroll.scroll_end(animate=False)

class Message(Widget):
def __init__(self, message: str) -> None:
super().__init__()
Expand Down

0 comments on commit dafd93d

Please sign in to comment.