Skip to content

Commit

Permalink
Start of frontend changes
Browse files Browse the repository at this point in the history
Time travel, use checkpoint as primary source of truth

Refactor state management for chat window

Add support for state graph

Fixes

Pare down unneeded functionality, frontend updates

Fix repeated history fetches

Add basic state graph support, many other fixes

Revise state graph time travel flow

Use message graph as default

Fix flashing messages in UI on send

Allow adding and deleting tool calls

Hacks!

Only accept module paths

More logs

add env

add built ui files

Build ui files

Update cli

Delete .github/workflows/build_deploy_image.yml

Update path

Update ui files

Move migrations

Move ui files

0.0.5

Allow resume execution for tool messages (#2)

Undo

Undo

Remove cli

Undo

Undo

Update storage/threads

Undo ui

Undo

Lint

Undo

Rm

Undo

Rm

Update api

Undo

WIP
  • Loading branch information
jacoblee93 authored and nfcampos committed Apr 15, 2024
1 parent cb39b9b commit 0dab4d6
Show file tree
Hide file tree
Showing 12 changed files with 678 additions and 25 deletions.
7 changes: 5 additions & 2 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Annotated, Any, Dict, List, Sequence, Union
from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
Expand Down Expand Up @@ -27,6 +27,7 @@ class ThreadPostRequest(BaseModel):
"""Payload for adding state to a thread."""

values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: Optional[Dict[str, Any]] = None


@router.get("/")
Expand Down Expand Up @@ -60,7 +61,9 @@ async def add_thread_state(
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return await storage.update_thread_state(user["user_id"], tid, payload.values)
return await storage.update_thread_state(
payload.config or {"configurable": {"thread_id": tid}}, payload.values
)


@router.get("/{tid}/history")
Expand Down
12 changes: 6 additions & 6 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig

from app.agent import AgentType, get_agent_executor
from app.lifespan import get_pg_pool
Expand Down Expand Up @@ -109,26 +110,25 @@ async def get_thread_state(user_id: str, thread_id: str):


async def update_thread_state(
user_id: str, thread_id: str, values: Union[Sequence[AnyMessage], Dict[str, Any]]
config: RunnableConfig, values: Union[Sequence[AnyMessage], dict[str, Any]]
):
"""Add state to a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
await app.aupdate_state({"configurable": {"thread_id": thread_id}}, values)
return await app.aupdate_state(config, values)


async def get_thread_history(user_id: str, thread_id: str):
"""Get the history of a thread."""
app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
config = {"configurable": {"thread_id": thread_id}}
return [
{
"values": c.values,
"next": c.next,
"config": c.config,
"parent": c.parent_config,
}
async for c in app.aget_state_history(
{"configurable": {"thread_id": thread_id}}
)
async for c in app.aget_state_history(config)
]


Expand Down
4 changes: 4 additions & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"name": "frontend",
"private": true,
"version": "0.0.0",
"packageManager": "yarn@1.22.19",
"type": "module",
"scripts": {
"dev": "vite --host",
Expand All @@ -11,9 +12,12 @@
"format": "prettier -w src"
},
"dependencies": {
"@emotion/react": "^11.11.4",
"@emotion/styled": "^11.11.0",
"@headlessui/react": "^1.7.17",
"@heroicons/react": "^2.0.18",
"@microsoft/fetch-event-source": "^2.0.1",
"@mui/material": "^5.15.14",
"@tailwindcss/forms": "^0.5.6",
"@tailwindcss/typography": "^0.5.10",
"clsx": "^2.0.0",
Expand Down
7 changes: 6 additions & 1 deletion frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ function App(props: { edit?: boolean }) {
const { currentChat, assistantConfig, isLoading } = useThreadAndAssistant();

const startTurn = useCallback(
async (message: MessageWithFiles | null, thread_id: string) => {
async (
message: MessageWithFiles | null,
thread_id: string,
config?: Record<string, unknown>,
) => {
const files = message?.files || [];
if (files.length > 0) {
const formData = files.reduce((formData, file) => {
Expand Down Expand Up @@ -55,6 +59,7 @@ function App(props: { edit?: boolean }) {
]
: null,
thread_id,
config,
);
},
[startStream],
Expand Down
21 changes: 21 additions & 0 deletions frontend/src/assets/EmptyState.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions frontend/src/components/AutosizeTextarea.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { Ref } from "react";
import { cn } from "../utils/cn";

const COMMON_CLS = cn(
"text-sm col-[1] row-[1] m-0 resize-none overflow-hidden whitespace-pre-wrap break-words bg-transparent px-2 py-1 rounded shadow-none",
);

export function AutosizeTextarea(props: {
id?: string;
inputRef?: Ref<HTMLTextAreaElement>;
value?: string | null | undefined;
placeholder?: string;
className?: string;
onChange?: (e: string) => void;
onFocus?: () => void;
onBlur?: () => void;
onKeyDown?: (e: React.KeyboardEvent<HTMLTextAreaElement>) => void;
autoFocus?: boolean;
readOnly?: boolean;
cursorPointer?: boolean;
disabled?: boolean;
fullHeight?: boolean;
}) {
return (
<div
className={
cn("grid w-full", props.className) +
(props.fullHeight ? "" : " max-h-80 overflow-auto ")
}
>
<textarea
ref={props.inputRef}
id={props.id}
className={cn(
COMMON_CLS,
"text-transparent caret-black rounded focus:outline-0 focus:ring-0",
)}
disabled={props.disabled}
value={props.value ?? ""}
rows={1}
onChange={(e) => {
const target = e.target as HTMLTextAreaElement;
props.onChange?.(target.value);
}}
onFocus={props.onFocus}
onBlur={props.onBlur}
placeholder={props.placeholder}
readOnly={props.readOnly}
autoFocus={props.autoFocus && !props.readOnly}
onKeyDown={props.onKeyDown}
/>
<div
aria-hidden
className={cn(COMMON_CLS, "pointer-events-none select-none")}
>
{props.value}{" "}
</div>
</div>
);
}
14 changes: 11 additions & 3 deletions frontend/src/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ import { ArrowDownCircleIcon } from "@heroicons/react/24/outline";
import { MessageWithFiles } from "../utils/formTypes.ts";
import { useParams } from "react-router-dom";
import { useThreadAndAssistant } from "../hooks/useThreadAndAssistant.ts";
// import { useHistories } from "../hooks/useHistories.ts";
// import { Timeline } from "./Timeline.tsx";
// import { deepEquals } from "../utils/equals.ts";

interface ChatProps extends Pick<StreamStateProps, "stream" | "stopStream"> {
interface ChatProps
extends Pick<
StreamStateProps,
"stream" | "stopStream" | "streamErrorMessage"
> {
startStream: (
message: MessageWithFiles | null,
thread_id: string,
Expand Down Expand Up @@ -66,9 +73,10 @@ export function Chat(props: ChatProps) {
...
</div>
)}
{props.stream?.status === "error" && (
{(props.streamErrorMessage || props.stream?.status === "error") && (
<div className="flex items-center rounded-md bg-yellow-50 px-2 py-1 text-xs font-medium text-yellow-800 ring-1 ring-inset ring-yellow-600/20">
An error has occurred. Please try again.
{props.streamErrorMessage ??
"An error has occurred. Please try again."}
</div>
)}
{next.length > 0 && props.stream?.status !== "inflight" && (
Expand Down
59 changes: 59 additions & 0 deletions frontend/src/components/Timeline.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { useState } from "react";
import { Slider } from "@mui/material";
import {
ChevronLeftIcon,
ChevronRightIcon,
ClockIcon,
} from "@heroicons/react/24/outline";
import { cn } from "../utils/cn";
import { History } from "../hooks/useHistories";

export function Timeline(props: {
disabled: boolean;
histories: History[];
activeHistoryIndex: number;
onChange?: (newValue: number) => void;
}) {
const [expanded, setExpanded] = useState(false);
return (
<div className="flex items-center">
<button
className="flex items-center p-2 text-sm"
type="submit"
disabled={props.disabled}
onClick={() => setExpanded((expanded) => !expanded)}
>
<ClockIcon className="h-4 w-4 mr-1 rounded-md shrink-0" />
<span className="text-gray-900 font-semibold shrink-0">
Time travel
</span>
</button>
<Slider
className={cn(
"w-full shrink transition-max-width duration-200",
expanded ? " ml-8 mr-8 max-w-full" : " invisible max-w-0",
)}
aria-label="Timeline"
value={props.activeHistoryIndex}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
onChange={(e) => props.onChange?.((e.target as any).value)}
valueLabelDisplay="auto"
step={1}
marks
min={0}
max={props.histories.length - 1}
/>
{expanded ? (
<ChevronLeftIcon
className="h-4 w-4 cursor-pointer shrink-0 mr-4"
onClick={() => setExpanded((expanded) => !expanded)}
/>
) : (
<ChevronRightIcon
className="h-4 w-4 cursor-pointer shrink-0 mr-4"
onClick={() => setExpanded((expanded) => !expanded)}
/>
)}
</div>
);
}
41 changes: 41 additions & 0 deletions frontend/src/hooks/useHistories.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { useEffect, useState } from "react";
import { Message as MessageType } from "./useChatList";
import { StreamState } from "./useStreamState";

async function getHistories(threadId: string) {
const response = await fetch(`/threads/${threadId}/history`, {
headers: {
Accept: "application/json",
},
}).then((r) => r.json());
return response;
}

export interface History {
values: MessageType[];
next: string[];
config: Record<string, unknown>;
}

export function useHistories(
threadId: string | null,
stream: StreamState | null,
): {
histories: History[];
setHistories: React.Dispatch<React.SetStateAction<History[]>>;
} {
const [histories, setHistories] = useState<History[]>([]);

useEffect(() => {
async function fetchHistories() {
if (threadId) {
const histories = await getHistories(threadId);
setHistories(histories);
}
}
fetchHistories();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [threadId, stream?.status]);

return { histories, setHistories };
}
28 changes: 25 additions & 3 deletions frontend/src/hooks/useStreamState.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,30 @@ export interface StreamState {

export interface StreamStateProps {
stream: StreamState | null;
startStream: (input: Message[] | null, thread_id: string) => Promise<void>;
startStream: (
input: Message[] | null,
thread_id: string,
config?: Record<string, unknown>,
) => Promise<void>;
stopStream?: (clear?: boolean) => void;
setStreamStateStatus: (status: "inflight" | "error" | "done") => void;
streamErrorMessage: string | null;
setStreamErrorMessage: (message: string | null) => void;
}

export function useStreamState(): StreamStateProps {
const [current, setCurrent] = useState<StreamState | null>(null);
const [controller, setController] = useState<AbortController | null>(null);
const [streamErrorMessage, setStreamErrorMessage] = useState<string | null>(
null,
);

const startStream = useCallback(
async (input: Message[] | null, thread_id: string) => {
async (
input: Message[] | null,
thread_id: string,
config?: Record<string, unknown>,
) => {
const controller = new AbortController();
setController(controller);
setCurrent({ status: "inflight", messages: input || [] });
Expand All @@ -28,7 +42,7 @@ export function useStreamState(): StreamStateProps {
signal: controller.signal,
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ input, thread_id }),
body: JSON.stringify({ input, thread_id, config }),
openWhenHidden: true,
onmessage(msg) {
if (msg.event === "data") {
Expand All @@ -51,6 +65,7 @@ export function useStreamState(): StreamStateProps {
messages: current?.messages,
run_id: current?.run_id,
}));
setStreamErrorMessage("Error received while streaming output.");
}
},
onclose() {
Expand All @@ -67,6 +82,7 @@ export function useStreamState(): StreamStateProps {
messages: current?.messages,
run_id: current?.run_id,
}));
setStreamErrorMessage("Error in stream.");
setController(null);
throw error;
},
Expand Down Expand Up @@ -95,10 +111,16 @@ export function useStreamState(): StreamStateProps {
[controller],
);

const setStreamStateStatus = (value: "inflight" | "error" | "done") =>
setCurrent((current) => ({ ...current, status: value }));

return {
startStream,
stopStream,
stream: current,
setStreamStateStatus,
streamErrorMessage,
setStreamErrorMessage,
};
}

Expand Down
Loading

0 comments on commit 0dab4d6

Please sign in to comment.