From dcff9e31f6c43af545e1e71b3dfb95e6c9b6bc36 Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Wed, 4 Dec 2024 14:31:36 -0500 Subject: [PATCH] Dialectic Streaming Endpoint Fix (#79) * fix: dialectic endpoint stream method * chore: docs --- CHANGELOG.md | 6 ++++++ README.md | 2 +- pyproject.toml | 2 +- src/agent.py | 24 ++++++++++-------------- src/main.py | 2 +- src/routers/sessions.py | 7 +++++-- 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71dba66..0236b80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## [0.0.15] + +### Fixed + +- Dialectic Streaming Endpoint properly sends text in `StreamingResponse` + ## [0.0.14] — 2024-11-14 ### Changed diff --git a/README.md b/README.md index d79945a..8c76844 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # 🫡 Honcho -![Static Badge](https://img.shields.io/badge/Version-0.0.14-blue) +![Static Badge](https://img.shields.io/badge/Version-0.0.15-blue) [![Discord](https://img.shields.io/discord/1016845111637839922?style=flat&logo=discord&logoColor=23ffffff&label=Plastic%20Labs&labelColor=235865F2)](https://discord.gg/plasticlabs) [![arXiv](https://img.shields.io/badge/arXiv-2310.06983-b31b1b.svg)](https://arxiv.org/abs/2310.06983) ![GitHub License](https://img.shields.io/github/license/plastic-labs/honcho) diff --git a/pyproject.toml b/pyproject.toml index 35e20e5..20da08b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "honcho" -version = "0.0.14" +version = "0.0.15" description = "Honcho Server" authors = [ {name = "Plastic Labs", email = "hello@plasticlabs.ai"}, diff --git a/src/agent.py b/src/agent.py index 52195c5..881d160 100644 --- a/src/agent.py +++ b/src/agent.py @@ -2,7 +2,7 @@ import os from collections.abc import Iterable -from anthropic import Anthropic +from anthropic import Anthropic, MessageStreamManager from dotenv import load_dotenv from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -31,9 +31,7 @@ def get_set(self) -> set[str]: class Dialectic: - def __init__( - self, agent_input: str, user_representation: str, chat_history: list[str] - ): + def __init__(self, agent_input: str, user_representation: str, chat_history: str): self.agent_input = agent_input self.user_representation = user_representation self.chat_history = chat_history @@ -68,8 +66,7 @@ def stream(self): {self.chat_history} Provide a brief, matter-of-fact, and appropriate response to the query based on the context provided. If the context provided doesn't aid in addressing the query, return only the word "None". """ - - yield from self.client.messages.create( + return self.client.messages.stream( model="claude-3-5-sonnet-20240620", messages=[ { @@ -78,26 +75,25 @@ def stream(self): } ], max_tokens=300, - stream=True, ) -async def chat_history(app_id: str, user_id: str, session_id: str) -> list[str]: +async def chat_history(app_id: str, user_id: str, session_id: str) -> str: async with SessionLocal() as db: stmt = await crud.get_messages(db, app_id, user_id, session_id) results = await db.execute(stmt) messages = results.scalars() - history = [] + history = "" for message in messages: if message.is_user: - history.append(f"user:{message.content}") + history += f"user:{message.content}\n" else: - history.append(f"assistant:{message.content}") + history += f"assistant:{message.content}\n" return history async def get_latest_user_representation( - db: AsyncSession, app_id: str, user_id: str, session_id: str + db: AsyncSession, app_id: str, user_id: str ) -> str: stmt = ( select(models.Metamessage) @@ -126,13 +122,13 @@ async def chat( session_id: str, query: schemas.AgentQuery, stream: bool = False, -): +) -> schemas.AgentChat | MessageStreamManager: questions = [query.queries] if isinstance(query.queries, str) else query.queries final_query = "\n".join(questions) if len(questions) > 1 else questions[0] async with SessionLocal() as db: # Run user representation retrieval and chat history retrieval concurrently - user_rep_task = get_latest_user_representation(db, app_id, user_id, session_id) + user_rep_task = get_latest_user_representation(db, app_id, user_id) history_task = chat_history(app_id, user_id, session_id) # Wait for both tasks to complete diff --git a/src/main.py b/src/main.py index 58a282d..7c2cbc3 100644 --- a/src/main.py +++ b/src/main.py @@ -47,7 +47,7 @@ async def lifespan(app: FastAPI): summary="An API for adding personalization to AI Apps", description="""This API is used to store data and get insights about users for AI applications""", - version="0.0.14", + version="0.0.15", contact={ "name": "Plastic Labs", "url": "https://plasticlabs.ai", diff --git a/src/routers/sessions.py b/src/routers/sessions.py index 432b4a6..e2ff8a4 100644 --- a/src/routers/sessions.py +++ b/src/routers/sessions.py @@ -1,5 +1,6 @@ from typing import Optional +from anthropic import MessageStreamManager from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from fastapi_pagination import Page @@ -150,8 +151,10 @@ async def parse_stream(): query=query, stream=True, ) - for chunk in stream: - yield chunk.content + if type(stream) is MessageStreamManager: + with stream as stream_manager: + for text in stream_manager.text_stream: + yield text return StreamingResponse( content=parse_stream(), media_type="text/event-stream", status_code=200