Skip to content

Commit

Permalink
Honcho Changes (#77)
Browse files Browse the repository at this point in the history
* init nextjs

* fast api init

* styling and thoughts

* streaming updates

* connect to api

* Add thoughts to the web UI

* Refactor input to be a form for UX (e.g. pressing enter sends)

* typing and thoughts

* Refactor input to be a form for UX (e.g. pressing enter sends)

* Revert "Merge remote-tracking branch 'origin/custom-web' into custom-web"

This reverts commit 1eae747.

* Skeleton Multiple Chat Window UI

* MVP Layout

* Tested Discord and Skeleton FastAPI

* Add, Delete, and Set Conversations

* Get and send messages

* Edit message names

* Local serving from FastAPI via static export

* Deployment strategy for static files

* Separate out apps

* Vercel Deployment with Action

* Re-add discord to fly.toml

* Honcho Stream

* Honcho Stream

---------

Co-authored-by: hyusap <paulayush@gmail.com>
Co-authored-by: Jacob Van Meter <jacobvm04@gmail.com>
  • Loading branch information
3 people authored Sep 11, 2023
1 parent 638c78a commit 0596350
Show file tree
Hide file tree
Showing 8 changed files with 816 additions and 1,517 deletions.
61 changes: 12 additions & 49 deletions agent/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from collections import OrderedDict
from .mediator import SupabaseMediator
import uuid
from typing import List, Tuple
from typing import List, Tuple, Dict
from langchain.schema import BaseMessage

class Conversation:
"Wrapper Class for storing contexts between channels. Using an object to pass by reference avoid additional cache hits"
def __init__(self, mediator: SupabaseMediator, user_id: str, conversation_id: str | None = None, location_id: str = "web", metadata: Dict | None = None):
    """Bind a conversation's identifiers and backing mediator.

    Args:
        mediator: Persistence layer used for all reads/writes.
        user_id: Owner of the conversation.
        conversation_id: Existing id, or None to mint a fresh UUID.
        location_id: Channel/location key (defaults to "web").
        metadata: Row metadata (e.g. the A/B flag), or None for empty.

    Note: the original signature used `str(uuid.uuid4())` and `{}` as
    literal defaults — both are evaluated once at definition time, so every
    default-constructed Conversation shared one id and one metadata dict.
    None-sentinels generate a fresh value per call.
    """
    self.mediator: SupabaseMediator = mediator
    self.user_id: str = user_id
    self.conversation_id: str = conversation_id if conversation_id is not None else str(uuid.uuid4())
    self.location_id: str = location_id
    self.metadata: Dict = metadata if metadata is not None else {}

def add_message(self, message_type: str, message: BaseMessage,) -> None:
    """Persist one message for this conversation through the mediator."""
    self.mediator.add_message(
        self.conversation_id,
        self.user_id,
        message_type,
        message,
    )
Expand All @@ -27,7 +28,9 @@ def delete(self) -> None:

def restart(self) -> None:
    """Soft-delete the current conversation and begin a fresh one.

    The mediator returns the new row's representation; both the id and
    the metadata (e.g. the A/B bucket) are refreshed from it. The stale
    pre-refactor assignment that called add_conversation with the old
    positional API has been removed — it duplicated the call.
    """
    self.delete()
    representation = self.mediator.add_conversation(user_id=self.user_id, location_id=self.location_id)
    self.conversation_id = representation["id"]
    self.metadata = representation["metadata"]


class LRUCache:
Expand Down Expand Up @@ -66,27 +69,28 @@ def get(self, user_id: str, location_id: str) -> None | Conversation:
key = location_id+user_id
if key in self.memory_cache:
return self.memory_cache[key]

conversation = self.mediator.conversations(location_id=location_id, user_id=user_id)
if conversation:
conversation_id = conversation[0]["id"]
metadata = conversation[0]["metadata"]
# Add the conversation data to the memory_cache
if len(self.memory_cache) >= self.capacity:
self.memory_cache.popitem(last=False)
self.memory_cache[key] = Conversation(self.mediator, location_id=location_id, user_id=user_id, conversation_id=conversation_id)
self.memory_cache[key] = Conversation(self.mediator, location_id=location_id, user_id=user_id, conversation_id=conversation_id, metadata=metadata)
return self.memory_cache[key]

return None

def put(self, user_id: str, location_id: str) -> Conversation:
    """Create a new conversation row and cache it, evicting the LRU entry if full.

    Args:
        user_id: Owner of the conversation.
        location_id: Channel/location key; combined with user_id as the cache key.

    Returns:
        The freshly cached Conversation wrapper.

    Note: a stale pre-refactor line that called add_conversation a second
    time (old string-returning API) has been removed — it caused a
    duplicate row insert per put().
    """
    representation: Dict = self.mediator.add_conversation(location_id=location_id, user_id=user_id)
    conversation_id = representation["id"]
    metadata = representation["metadata"]
    key: str = location_id + user_id

    if len(self.memory_cache) >= self.capacity:
        # Remove the least recently used item from the memory cache
        self.memory_cache.popitem(last=False)
    self.memory_cache[key] = Conversation(self.mediator, location_id=location_id, user_id=user_id, conversation_id=conversation_id, metadata=metadata)
    return self.memory_cache[key]

def get_or_create(self, user_id: str, location_id: str, restart: bool = False) -> Conversation:
Expand All @@ -96,44 +100,3 @@ def get_or_create(self, user_id: str, location_id: str, restart: bool = False) -
elif restart:
cache.restart()
return cache

# class LayeredLRUConversationCache:
# """A Conversation LRU Cache that bases keys on the conversation_id of a conversation. The assumption is that the conversation is the unique identifier"""
# def __init__(self, capacity, mediator: SupabaseMediator):
# self.capacity = capacity
# self.memory_cache = OrderedDict()
# self.mediator = mediator
#
# def get(self, user_id: str, conversation_id: str) -> None | Conversation:
# key = conversation_id+user_id
# if key in self.memory_cache:
# return self.memory_cache[key]
#
# location_id = self.mediator.conversation(conversation_id)
# if location_id:
# # Add the conversation data to the memory_cache
# if len(self.memory_cache) >= self.capacity:
# self.memory_cache.popitem(last=False)
# self.memory_cache[key] = Conversation(self.mediator, location_id=location_id, user_id=user_id, conversation_id=conversation_id)
# return self.memory_cache[key]
#
# return None
#
# def put(self, user_id: str, location_id: str) -> Conversation:
# # Add the conversation data to the postgres via the mediator
# conversation_id = self.mediator.add_conversation(location_id=location_id, user_id=user_id)
# key: str = conversation_id+user_id
#
# if len(self.memory_cache) >= self.capacity:
# # Remove the least recently used item from the memory cache
# self.memory_cache.popitem(last=False)
# self.memory_cache[key] = Conversation(self.mediator, location_id=location_id, user_id=user_id, conversation_id=conversation_id)
# return self.memory_cache[key]
#
#
# def hard_delete(self, user_id: str, conversation_id: str) -> None:
# key = conversation_id+user_id
# if key in self.memory_cache:
# self.memory_cache.pop(key)
# self.mediator.delete_conversation(conversation_id)
#
37 changes: 26 additions & 11 deletions agent/mediator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from supabase.client import create_client, Client
from typing import List, Tuple, Dict
import json
import random
load_dotenv()

class SupabaseMediator:
Expand All @@ -29,9 +30,15 @@ def messages(self, session_id: str, user_id: str, message_type: str, limit: Tupl
return messages[::-1]

def add_message(self, session_id: str, user_id: str, message_type: str, message: BaseMessage) -> None:
    """Insert one message row into the memory table.

    The stale pre-refactor one-line insert has been removed — with both
    versions present, every message was inserted twice.
    """
    payload = {
        "session_id": session_id,
        "user_id": user_id,
        "message_type": message_type,
        "message": _message_to_dict(message),
    }
    self.supabase.table(self.memory_table).insert(payload).execute()

def conversations(self, location_id: str, user_id: str, single: bool = True) -> List[Dict] | None:
try:
response = self.supabase.table(self.conversation_table).select(*["id", "metadata"], count="exact").eq("location_id", location_id).eq("user_id", user_id).eq("isActive", True).order("created_at", desc=True).execute()
if response is not None and response.count is not None:
Expand All @@ -50,21 +57,29 @@ def conversations(self, location_id: str, user_id: str, single: bool = True, met
return None


def conversation(self, session_id: str) -> Dict | None:
    """Fetch the full row of an active conversation, or None if not found.

    Returns the entire row dict (id, user_id, location_id, metadata, ...)
    rather than just location_id. The interleaved old-version def and its
    dead location_id statements have been removed.
    """
    response = self.supabase.table(self.conversation_table).select("*").eq("id", session_id).eq("isActive", True).maybe_single().execute()
    if response:
        return response.data
    return None

def _cleanup_conversations(self, conversation_ids: List[str]) -> None:
    """Soft-delete each listed conversation by flipping its isActive flag."""
    inactive = {"isActive": False}
    for cid in conversation_ids:
        self.supabase.table(self.conversation_table).update(inactive).eq("id", cid).execute()

def add_conversation(self, location_id: str, user_id: str) -> Dict:
    """Create a new conversation row and return its full representation.

    Each new conversation is randomly assigned to the A/B experiment
    bucket (50/50 via one random bit) in its metadata.

    Returns:
        The inserted row as a dict (id, user_id, location_id, metadata, ...).

    Changes: removed the leftover debug print banner, and the stale
    pre-refactor insert/return (old string-returning API) that would have
    inserted a duplicate row.
    """
    conversation_id = str(uuid.uuid4())
    payload = {
        "id": conversation_id,
        "user_id": user_id,
        "location_id": location_id,
        # 50/50 experiment assignment; read downstream to route to Honcho.
        "metadata": {"A/B": bool(random.getrandbits(1))},
    }
    representation = self.supabase.table(self.conversation_table).insert(payload, returning="representation").execute()  # type: ignore
    return representation.data[0]

def delete_conversation(self, conversation_id: str) -> None:
    """Soft-delete a single conversation (marks it inactive; no row removal)."""
    update_payload = {"isActive": False}
    self.supabase.table(self.conversation_table).update(update_payload).eq("id", conversation_id).execute()
Expand All @@ -76,7 +91,7 @@ def update_conversation(self, conversation_id: str, metadata: Dict) -> None:
new_metadata.update(metadata)
else:
new_metadata = metadata
self.supabase.table(self.conversation_table).update({"metadata": new_metadata}, returning="representation").eq("id", conversation_id).execute()
self.supabase.table(self.conversation_table).update({"metadata": new_metadata}, returning="representation").eq("id", conversation_id).execute() # type: ignore


# Modification of PostgresChatMessageHistory: https://api.python.langchain.com/en/latest/_modules/langchain/memory/chat_message_histories/postgres.html#PostgresChatMessageHistory
Expand Down
39 changes: 31 additions & 8 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from common import init
from agent.chain import BloomChain
Expand Down Expand Up @@ -68,7 +69,8 @@ async def delete_conversation(conversation_id: str):
@app.get("/api/conversations/insert")
async def add_conversation(user_id: str, location_id: str = "web"):
async with LOCK:
conversation_id = MEDIATOR.add_conversation(location_id=location_id, user_id=user_id)
representation = MEDIATOR.add_conversation(location_id=location_id, user_id=user_id)
conversation_id = representation["id"]
return {
"conversation_id": conversation_id
}
Expand All @@ -92,6 +94,17 @@ async def get_messages(user_id: str, conversation_id: str):
async def chat(inp: ConversationInput):
async with LOCK:
conversation = Conversation(MEDIATOR, user_id=inp.user_id, conversation_id=inp.conversation_id)
conversation_data = MEDIATOR.conversation(session_id=inp.conversation_id)
if conversation_data and conversation_data["metadata"]:
metadata = conversation_data["metadata"]
if metadata["A/B"]:
response = requests.post(f'{os.environ["HONCHO_URL"]}/chat', json={
"user_id": inp.user_id,
"conversation_id": inp.conversation_id,
"message": inp.message
}, stream=True)
print(response)
return response
if conversation is None:
raise HTTPException(status_code=404, detail="Item not found")
thought, response = await BloomChain.chat(conversation, inp.message)
Expand All @@ -104,14 +117,25 @@ async def chat(inp: ConversationInput):
async def stream(inp: ConversationInput):
async with LOCK:
conversation = Conversation(MEDIATOR, user_id=inp.user_id, conversation_id=inp.conversation_id)
conversation_data = MEDIATOR.conversation(session_id=inp.conversation_id)
if conversation_data and conversation_data["metadata"]:
metadata = conversation_data["metadata"]
if metadata["A/B"]:
response = requests.post(f'{os.environ["HONCHO_URL"]}/stream', json={
"user_id": inp.user_id,
"conversation_id": inp.conversation_id,
"message": inp.message
}, stream=True)

def generator():
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk

print("A/B Confirmed")
return StreamingResponse(generator())
if conversation is None:
raise HTTPException(status_code=404, detail="Item not found")
print()
print()
print("local chain", conversation.messages("thought"), conversation.messages("response"))
print()
print()


async def thought_and_response():
thought_iterator = BloomChain.think(conversation, inp.message)
Expand All @@ -130,4 +154,3 @@ async def thought_and_response():

return StreamingResponse(thought_and_response())

# app.mount("/", StaticFiles(directory="www/out", html=True), name="static")
22 changes: 18 additions & 4 deletions bot/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# core functionality

import discord
import os
from __main__ import (
CACHE,
LOCK,
Expand All @@ -10,6 +11,7 @@
from typing import Optional
from agent.chain import BloomChain
from langchain.schema import AIMessage, HumanMessage, BaseMessage
import httpx


class Core(commands.Cog):
Expand Down Expand Up @@ -44,9 +46,10 @@ async def on_message(self, message):
if message.author == self.bot.user:
return

user_id = f"discord_{str(message.author.id)}"
# Get cache for conversation
async with LOCK:
CONVERSATION = CACHE.get_or_create(location_id=str(message.channel.id), user_id=f"discord_{str(message.author.id)}")
CONVERSATION = CACHE.get_or_create(location_id=str(message.channel.id), user_id=user_id)

# Get the message content but remove any mentions
inp = message.content.replace(str('<@' + str(self.bot.user.id) + '>'), '')
Expand All @@ -55,7 +58,20 @@ async def on_message(self, message):
async def respond(reply = True, forward_thought = True):
"Generate response too user"
async with message.channel.typing():
thought, response = await BloomChain.chat(CONVERSATION, inp)
thought = ""
response = ""
if (CONVERSATION.metadata is not None and "A/B" in CONVERSATION.metadata and CONVERSATION.metadata["A/B"] == True):
async with httpx.AsyncClient() as client:
response = await client.post(f'{os.environ["HONCHO_URL"]}/chat', json={
"user_id": CONVERSATION.user_id,
"conversation_id": CONVERSATION.conversation_id,
"message": inp
}, timeout=None)
response_text = response.json()
thought = response_text["thought"]
response = response_text["response"]
else:
thought, response = await BloomChain.chat(CONVERSATION, inp)

# sanitize thought by adding zero width spaces to triple backticks
thought = thought.replace("```", "`\u200b`\u200b`")
Expand Down Expand Up @@ -157,7 +173,5 @@ async def restart(self, ctx: discord.ApplicationContext, respond: Optional[bool]
return




def setup(bot):
    """Register the Core cog on the given Discord bot instance."""
    core_cog = Core(bot)
    bot.add_cog(core_cog)
Loading

0 comments on commit 0596350

Please sign in to comment.