Commit

fix some auto streaming bugs
dnakov committed Nov 7, 2024
1 parent 096987e commit fab23ef
Showing 2 changed files with 40 additions and 22 deletions.
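In short: streaming is no longer chosen per call. The stream flag moves into ChatAuto's constructor (defaulting to True), and achat(), get_completion(), and attempt_completion() all read self.stream instead of taking a stream argument. A minimal, self-contained sketch of that pattern — a toy class with hypothetical names, not the project's actual code:

import asyncio

class ToyChat:
    # Toy illustration of the commit's pattern: decide streaming once, at
    # construction time, instead of threading a stream= argument through
    # every call.
    def __init__(self, stream=True):
        self.stream = stream

    async def achat(self, messages):
        # Every internal step consults self.stream.
        if self.stream:
            return [f"chunk: {m}" for m in messages]  # streamed pieces
        return " ".join(messages)                     # one final answer

async def main():
    print(await ToyChat(stream=True).achat(["hello", "world"]))
    print(await ToyChat(stream=False).achat(["hello", "world"]))

asyncio.run(main())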
r2ai/auto.py (60 changes: 39 additions & 21 deletions)
@@ -39,7 +39,7 @@
 """
 
 class ChatAuto:
-    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, cb=None ):
+    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ):
         self.logger = LOGGER
         self.functions = {}
         self.tools = []
@@ -52,6 +52,7 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
         self.interpreter = interpreter
         self.system_message = None
         self.timeout = timeout
+        self.stream = stream
         if messages and messages[0]['role'] != 'system' and system:
             self.messages.insert(0, { "role": "system", "content": system })
         if cb:
@@ -66,7 +67,6 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
                 self.functions[f['name']] = tool
         self.tool_choice = tool_choice
         self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
-
         #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
 
     async def process_tool_calls(self, tool_calls):
@@ -186,31 +186,37 @@ async def async_response_generator(self, response):
             resp = ModelResponse(stream=True, **item)
             yield resp
 
-    async def attempt_completion(self, stream=True):
-        args = {
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "max_tokens": self.max_tokens,
-            "stream": stream,
-        }
-
+    async def attempt_completion(self):
+        stream = self.stream
         if self.llama_instance:
+            args = {
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "max_tokens": self.max_tokens,
+                "stream": stream,
+            }
             res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
             if args['stream']:
                 return self.async_response_generator(res)
             else:
                 return ModelResponse(**next(res))
-
+        self.logger.debug('chat completion')
         return await acompletion(
             model=self.model,
             messages=self.messages,
             timeout=self.timeout,
-            **args
+            tools=self.tools,
+            tool_choice=self.tool_choice,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_tokens=self.max_tokens,
+            stream=stream,
         )
 
-    async def get_completion(self, stream=False):
+    async def get_completion(self):
+        stream = self.stream
         if self.llama_instance:
-            response = await self.attempt_completion(stream=stream)
+            response = await self.attempt_completion()
             if stream:
                 return await self.process_streaming_response(response)
             else:
@@ -220,7 +226,8 @@ async def get_completion(self, stream=False):
 
         for retry_count in range(max_retries):
             try:
-                response = await self.attempt_completion(stream=stream)
+                response = await self.attempt_completion()
+                self.logger.debug(f'chat completion {response}')
                 if stream:
                     return await self.process_streaming_response(response)
                 else:
@@ -236,10 +243,11 @@
 
         raise Exception("Max retries reached. Unable to get completion.")
 
-    async def achat(self, messages=None, stream=False) -> str:
+    async def achat(self, messages=None) -> str:
         if messages:
             self.messages = messages
-        response = await self.get_completion(stream)
+        response = await self.get_completion()
+        self.logger.debug(f'chat complete')
         return response
 
     def chat(self, **kwargs) -> str:
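For orientation, the non-llama branch of the rewritten attempt_completion() amounts to an async chat-completion call with every knob passed explicitly and stream taken from the instance. A hedged sketch of that flow, assuming the acompletion used above is litellm's async API (which the keyword arguments suggest); the model name and messages below are placeholders:

import asyncio
from litellm import acompletion  # assumption: the acompletion above is litellm's

async def complete(messages, stream=True):
    # Mirrors the rewritten attempt_completion(): explicit kwargs, one stream flag.
    response = await acompletion(
        model="gpt-4o-mini",      # placeholder model name
        messages=messages,
        temperature=0.0,
        top_p=0.95,
        max_tokens=1024,
        stream=stream,
    )
    if not stream:
        return response.choices[0].message.content
    parts = []
    async for chunk in response:  # streamed, OpenAI-style chunks
        delta = chunk.choices[0].delta
        if delta.content:
            parts.append(delta.content)
    return "".join(parts)

# Needs provider credentials to actually run:
# asyncio.run(complete([{"role": "user", "content": "hello"}]))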
@@ -289,16 +297,26 @@ def chat(interpreter, **kwargs):
     try:
         signal.signal(signal.SIGINT, signal_handler)
         spinner.start()
-        return loop.run_until_complete(chat_auto.achat(stream=True))
+        return loop.run_until_complete(chat_auto.achat())
     except KeyboardInterrupt:
         builtins.print("\033[91m\nOperation cancelled by user.\033[0m")
         tasks = asyncio.all_tasks(loop=loop)
         for task in tasks:
             task.cancel()
-        loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+        try:
+            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+            loop.run_until_complete(asyncio.sleep(0.1))
+        except asyncio.CancelledError:
+            pass
         return None
     finally:
         signal.signal(signal.SIGINT, original_handler)
         spinner.stop()
-        loop.stop()
-        loop.close()
+        try:
+            pending = asyncio.all_tasks(loop=loop)
+            for task in pending:
+                task.cancel()
+            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+            loop.run_until_complete(loop.shutdown_asyncgens())
+        finally:
+            loop.close()
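The new finally block is essentially the standard asyncio teardown recipe: cancel whatever is still pending, let the cancellations settle, flush async generators, and only then close the loop (the previous loop.stop(); loop.close() could leave pending tasks and open async generators behind). A standalone sketch of the same pattern, independent of r2ai:

import asyncio

def run_with_cleanup(coro):
    # Same shutdown sequence as the finally block above.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        try:
            pending = asyncio.all_tasks(loop=loop)
            for task in pending:
                task.cancel()                    # ask leftover tasks to stop
            loop.run_until_complete(
                asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()                         # close only after cleanup

async def demo():
    asyncio.ensure_future(asyncio.sleep(10))     # deliberately left pending
    await asyncio.sleep(0)
    return "done"

print(run_with_cleanup(demo()))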
r2ai/ui/chat.py (2 changes: 1 addition & 1 deletion)
@@ -22,4 +22,4 @@ async def chat(ai, message, cb):
 
     chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
 
-    return await chat_auto.achat(stream=True)
+    return await chat_auto.achat()
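Since stream now defaults to True in ChatAuto's constructor, this caller simply drops the per-call argument. A hypothetical non-streaming variant, reusing the names from the call above (passing stream=False at construction is an assumption, not part of this commit):

# Hypothetical: opt out of streaming at construction time instead of per call.
chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, stream=False, cb=cb)
return await chat_auto.achat()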
