fix some auto streaming bugs #88

Merged 1 commit on Nov 7, 2024
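
Summary of the diff: streaming is now configured once on ChatAuto (a stream constructor flag stored as self.stream, default True) instead of being threaded through attempt_completion(), get_completion(), and achat(); the litellm path now passes tools, tool_choice, and the sampling parameters explicitly; and chat() shuts its event loop down more carefully, cancelling pending tasks and flushing async generators before closing.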
r2ai/auto.py (39 additions, 21 deletions)
```diff
@@ -39,7 +39,7 @@
 """
 
 class ChatAuto:
-    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, cb=None ):
+    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, stream=True, cb=None ):
         self.logger = LOGGER
         self.functions = {}
         self.tools = []
@@ -52,6 +52,7 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
         self.interpreter = interpreter
         self.system_message = None
         self.timeout = timeout
+        self.stream = stream
         if messages and messages[0]['role'] != 'system' and system:
             self.messages.insert(0, { "role": "system", "content": system })
         if cb:
@@ -66,7 +67,6 @@ def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interp
                 self.functions[f['name']] = tool
         self.tool_choice = tool_choice
         self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
-
         #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
 
     async def process_tool_calls(self, tool_calls):
```
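
The net effect of these constructor changes is that streaming is decided once, when the ChatAuto is built. A minimal, untested usage sketch (the model identifier, system prompt, and message are placeholders, not taken from the PR):

```python
import asyncio

from r2ai.auto import ChatAuto

async def main() -> str:
    chat_auto = ChatAuto(
        "openai/gpt-4o",  # placeholder model identifier
        system="You are a helpful assistant.",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,  # new: stored on the instance as self.stream
    )
    # achat() no longer takes a stream argument; it reads self.stream.
    return await chat_auto.achat()

asyncio.run(main())
```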
```diff
@@ -186,31 +186,37 @@ async def async_response_generator(self, response):
             resp = ModelResponse(stream=True, **item)
             yield resp
 
-    async def attempt_completion(self, stream=True):
-        args = {
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "max_tokens": self.max_tokens,
-            "stream": stream,
-        }
-
+    async def attempt_completion(self):
+        stream = self.stream
         if self.llama_instance:
+            args = {
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "max_tokens": self.max_tokens,
+                "stream": stream,
+            }
             res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
             if args['stream']:
                 return self.async_response_generator(res)
             else:
                 return ModelResponse(**next(res))
 
         self.logger.debug('chat completion')
         return await acompletion(
             model=self.model,
             messages=self.messages,
             timeout=self.timeout,
-            **args
+            tools=self.tools,
+            tool_choice=self.tool_choice,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            max_tokens=self.max_tokens,
+            stream=stream,
         )
 
-    async def get_completion(self, stream=False):
+    async def get_completion(self):
+        stream = self.stream
         if self.llama_instance:
-            response = await self.attempt_completion(stream=stream)
+            response = await self.attempt_completion()
             if stream:
                 return await self.process_streaming_response(response)
             else:
```
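
For orientation (not part of the diff): when stream is true, the litellm acompletion call above returns an async iterator of chunks. A sketch of the usual way such a response is drained, assuming litellm's standard chunk shape (choices[0].delta.content); process_streaming_response presumably does something along these lines internally:

```python
async def drain_stream(response) -> str:
    # Collect and echo text deltas as they arrive from the stream.
    parts = []
    async for chunk in response:
        delta = chunk.choices[0].delta
        content = getattr(delta, "content", None)
        if content:
            print(content, end="", flush=True)
            parts.append(content)
    return "".join(parts)
```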
```diff
@@ -220,7 +226,8 @@ async def get_completion(self, stream=False):
 
         for retry_count in range(max_retries):
             try:
-                response = await self.attempt_completion(stream=stream)
+                response = await self.attempt_completion()
+                self.logger.debug(f'chat completion {response}')
                 if stream:
                     return await self.process_streaming_response(response)
                 else:
```
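
The scaffolding around this hunk (where max_retries comes from, what happens between attempts) is collapsed out of the diff. A generic sketch of the retry pattern it implements, with assumed values:

```python
import asyncio

async def with_retries(attempt, max_retries=3, base_delay=2.0):
    # Retry an async zero-argument callable, backing off linearly.
    # max_retries, base_delay, and the broad except are assumptions;
    # the real values live outside the visible hunk.
    last_exc = None
    for retry_count in range(max_retries):
        try:
            return await attempt()
        except Exception as exc:
            last_exc = exc
            await asyncio.sleep(base_delay * (retry_count + 1))
    raise Exception("Max retries reached. Unable to get completion.") from last_exc
```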
```diff
@@ -236,10 +243,11 @@ async def get_completion(self, stream=False):
 
         raise Exception("Max retries reached. Unable to get completion.")
 
-    async def achat(self, messages=None, stream=False) -> str:
+    async def achat(self, messages=None) -> str:
         if messages:
             self.messages = messages
-        response = await self.get_completion(stream)
+        response = await self.get_completion()
+        self.logger.debug(f'chat complete')
         return response
 
     def chat(self, **kwargs) -> str:
```
```diff
@@ -289,16 +297,26 @@ def chat(interpreter, **kwargs):
     try:
         signal.signal(signal.SIGINT, signal_handler)
         spinner.start()
-        return loop.run_until_complete(chat_auto.achat(stream=True))
+        return loop.run_until_complete(chat_auto.achat())
     except KeyboardInterrupt:
         builtins.print("\033[91m\nOperation cancelled by user.\033[0m")
         tasks = asyncio.all_tasks(loop=loop)
         for task in tasks:
             task.cancel()
-        loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+        try:
+            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
+            loop.run_until_complete(asyncio.sleep(0.1))
+        except asyncio.CancelledError:
+            pass
         return None
     finally:
         signal.signal(signal.SIGINT, original_handler)
         spinner.stop()
-        loop.stop()
-        loop.close()
+        try:
+            pending = asyncio.all_tasks(loop=loop)
+            for task in pending:
+                task.cancel()
+            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+            loop.run_until_complete(loop.shutdown_asyncgens())
+        finally:
+            loop.close()
```
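
The new finally block replaces the bare loop.stop()/loop.close(), which could leave cancelled tasks and async generators un-awaited and trigger "Task was destroyed but it is pending!" warnings. A standalone, runnable sketch of the same teardown pattern:

```python
import asyncio

def shutdown(loop: asyncio.AbstractEventLoop) -> None:
    # Mirror the teardown above: cancel anything still pending, let the
    # cancellations run to completion, flush async generators, then close.
    try:
        pending = asyncio.all_tasks(loop=loop)
        for task in pending:
            task.cancel()
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        loop.close()

loop = asyncio.new_event_loop()
loop.create_task(asyncio.sleep(60))  # a task that would otherwise leak
shutdown(loop)
```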
r2ai/ui/chat.py (1 addition, 1 deletion)
```diff
@@ -22,4 +22,4 @@ async def chat(ai, message, cb):
 
     chat_auto = ChatAuto(model, interpreter=ai, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb)
 
-    return await chat_auto.achat(stream=True)
+    return await chat_auto.achat()
```
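
Call sites simply stop passing stream=True; since ChatAuto's constructor now defaults to stream=True, this should leave the UI's streaming behavior unchanged.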