diff --git a/r2ai/auto.py b/r2ai/auto.py
index 05ca2af..2a68377 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -202,7 +202,7 @@ async def attempt_completion(self):
             "max_tokens": self.max_tokens,
             "stream": stream,
         }
-        res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
+        res = create_chat_completion(self.interpreter, messages=self.messages, tools=self.tools, **args)
         if args['stream']:
             return self.async_response_generator(res)
         else:
diff --git a/r2ai/completion.py b/r2ai/completion.py
index fe506b1..61d09f2 100644
--- a/r2ai/completion.py
+++ b/r2ai/completion.py
@@ -2,12 +2,7 @@
 import traceback
 from . import LOGGER
 import json
-from llama_cpp.llama_types import *
-from llama_cpp.llama_grammar import LlamaGrammar
-from llama_cpp.llama import StoppingCriteriaList, LogitsProcessorList
-from typing import List, Iterator, Dict, Any, Optional, Union, Callable, Sequence, Generator
 import uuid
-import llama_cpp
 import re
 from .partial_json_parser import parse_incomplete_json
 
@@ -16,6 +11,7 @@ def messages_to_prompt(self, messages, tools=None):
         # Happens if it immediatly writes code
         if "role" not in message:
             message["role"] = "assistant"
+
     lowermodel = self.model.lower()
     if "q4_0" in lowermodel:
         formatted_messages = template_q4im(self, messages)
@@ -41,8 +37,6 @@ def messages_to_prompt(self, messages, tools=None):
         formatted_messages = template_ferret(self, messages)
     elif "phi" in lowermodel:
         formatted_messages = template_phi3(self, messages)
-    elif "coder" in lowermodel:
-        formatted_messages = template_alpaca(self, messages)
     elif "deepseek" in lowermodel:
         formatted_messages = template_alpaca(self, messages)
     elif "llama-3.2" in lowermodel or "llama-3.1" in lowermodel:
@@ -75,6 +69,8 @@ def messages_to_prompt(self, messages, tools=None):
         formatted_messages = template_llamapython(self, messages)
     elif "tinyllama" in lowermodel:
         formatted_messages = template_tinyllama(self, messages)
+    elif "coder" in lowermodel:
+        formatted_messages = template_alpaca(self, messages)
     else:
         formatted_messages = template_llama(self, messages)
     LOGGER.debug(formatted_messages)
@@ -634,8 +630,8 @@ def response_qwen(self, response):
     curr_line = ""
     fn_call = None
     for text in response:
+        full_text += text
 
-
         if text == "\n":
             if curr_line.startswith("✿FUNCTION✿:"):
                 fn_call = { 'name': curr_line[11:].strip(), 'id': str(uuid.uuid4()), 'arguments': None }
@@ -644,6 +640,10 @@
                 yield delta_tool_call(id, fn_call['id'], fn_call['name'], fn_call['arguments'])
             lines.append(curr_line)
             curr_line = ""
+        elif text == '✿' and fn_call is None:
+            fn_call = {}
+            lines.append(curr_line)
+            curr_line = text
         else:
             curr_line += text
             if curr_line.startswith("✿"):
@@ -987,70 +987,24 @@ def template_llama(self,messages):
             formatted_messages += f"{content}[INST]"
     return formatted_messages
 
-def create_chat_completion(self, **kwargs):
+
+def get_completion_opts(self, **kwargs):
     messages = kwargs.pop('messages')
     tools = kwargs.pop('tools')
+    lowermodel = self.model.lower()
     prompt = messages_to_prompt(self, messages, tools)
-    return response_to_message(self, create_completion(self.llama_instance, prompt=prompt, **kwargs))
-
-
-def create_completion(
-    self,
-    prompt: Union[str, List[int]],
-    suffix: Optional[str] = None,
-    max_tokens: Optional[int] = 16,
-    temperature: float = 0.8,
-    top_p: float = 0.95,
-    min_p: float = 0.05,
-    typical_p: float = 1.0,
-    logprobs: Optional[int] = None,
-    echo: bool = False,
-    stop: Optional[Union[str, List[str]]] = [],
-    frequency_penalty: float = 0.0,
-    presence_penalty: float = 0.0,
-    repeat_penalty: float = 1.0,
-    top_k: int = 40,
-    stream: bool = False,
-    seed: Optional[int] = None,
-    tfs_z: float = 1.0,
-    mirostat_mode: int = 0,
-    mirostat_tau: float = 5.0,
-    mirostat_eta: float = 0.1,
-    model: Optional[str] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    grammar: Optional[LlamaGrammar] = None,
-    logit_bias: Optional[Dict[str, float]] = None,
-) -> Union[
-    Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
-]:
-
-    prompt_tokens = self.tokenize(
-        prompt.encode("utf-8"),
-        add_bos=False,
-        special=True
-    )
-
-    for token in self.generate(
-        prompt_tokens,
-        top_k=top_k,
-        top_p=top_p,
-        min_p=min_p,
-        typical_p=typical_p,
-        temp=temperature,
-        tfs_z=tfs_z,
-        mirostat_mode=mirostat_mode,
-        mirostat_tau=mirostat_tau,
-        mirostat_eta=mirostat_eta,
-        frequency_penalty=frequency_penalty,
-        presence_penalty=presence_penalty,
-        repeat_penalty=repeat_penalty,
-        stopping_criteria=stopping_criteria,
-        logits_processor=logits_processor,
-        grammar=grammar,
-    ):
-
-        if llama_cpp.llama_token_is_eog(self._model.model, token):
-            break
-        text = self.detokenize([token], special=True).decode("utf-8")
-        yield text
\ No newline at end of file
+    if tools and len(tools) > 0:
+        if "qwen" in lowermodel:
+            kwargs['stop'] = ["✿RESULT✿:", "✿RETURN✿:"]
+
+    kwargs['prompt'] = prompt
+    return kwargs
+
+def create_chat_completion(self, **kwargs):
+    opts = get_completion_opts(self, **kwargs)
+    completion = create_completion(self.llama_instance, **opts)
+    return response_to_message(self, completion)
+
+def create_completion(self, **kwargs):
+    for item in self.create_completion(**kwargs):
+        yield item['choices'][0]['text']
\ No newline at end of file
diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py
index 7d84267..214b9f8 100644
--- a/r2ai/interpreter.py
+++ b/r2ai/interpreter.py
@@ -382,7 +382,7 @@ def respond(self):
             response = auto.chat(self)
         else:
            self.llama_instance = new_get_hf_llm(self, self.model, int(self.env["llm.window"]))
-            response = auto.chat(self, llama_instance=self.llama_instance)
+            response = auto.chat(self)
             return
    elif self.model.startswith("kobaldcpp"):