Fix auto rawdog mode for qwen-coder
dnakov committed Nov 15, 2024
1 parent c1fa76e commit c778b57
Showing 3 changed files with 28 additions and 74 deletions.
2 changes: 1 addition & 1 deletion r2ai/auto.py
@@ -202,7 +202,7 @@ async def attempt_completion(self):
             "max_tokens": self.max_tokens,
             "stream": stream,
         }
-        res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
+        res = create_chat_completion(self.interpreter, messages=self.messages, tools=self.tools, **args)
         if args['stream']:
             return self.async_response_generator(res)
         else:
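
The one-line change above is the whole auto.py fix: the full tools list is now forwarded to create_chat_completion instead of only its first entry, so the model is offered every available function. A minimal sketch of the difference, using a hypothetical OpenAI-style tool list (the actual schemas r2ai builds are not shown in this diff):

import json

# Hypothetical tool schemas, only for illustration; r2ai's real definitions live outside this diff.
tools = [
    {"type": "function", "function": {"name": "r2cmd", "description": "Run a radare2 command"}},
    {"type": "function", "function": {"name": "run_python", "description": "Run a Python snippet"}},
]

# Old call site: only the first schema was forwarded.
print(json.dumps([t["function"]["name"] for t in [tools[0]]]))  # ["r2cmd"]

# New call site: every schema is forwarded.
print(json.dumps([t["function"]["name"] for t in tools]))       # ["r2cmd", "run_python"]
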
98 changes: 26 additions & 72 deletions r2ai/completion.py
@@ -2,12 +2,7 @@
 import traceback
 from . import LOGGER
 import json
-from llama_cpp.llama_types import *
-from llama_cpp.llama_grammar import LlamaGrammar
-from llama_cpp.llama import StoppingCriteriaList, LogitsProcessorList
-from typing import List, Iterator, Dict, Any, Optional, Union, Callable, Sequence, Generator
 import uuid
-import llama_cpp
 import re
 from .partial_json_parser import parse_incomplete_json

@@ -16,6 +11,7 @@ def messages_to_prompt(self, messages, tools=None):
         # Happens if it immediately writes code
         if "role" not in message:
             message["role"] = "assistant"
+
     lowermodel = self.model.lower()
     if "q4_0" in lowermodel:
         formatted_messages = template_q4im(self, messages)
@@ -41,8 +37,6 @@ def messages_to_prompt(self, messages, tools=None):
         formatted_messages = template_ferret(self, messages)
     elif "phi" in lowermodel:
         formatted_messages = template_phi3(self, messages)
-    elif "coder" in lowermodel:
-        formatted_messages = template_alpaca(self, messages)
     elif "deepseek" in lowermodel:
         formatted_messages = template_alpaca(self, messages)
     elif "llama-3.2" in lowermodel or "llama-3.1" in lowermodel:
@@ -75,6 +69,8 @@ def messages_to_prompt(self, messages, tools=None):
         formatted_messages = template_llamapython(self, messages)
     elif "tinyllama" in lowermodel:
         formatted_messages = template_tinyllama(self, messages)
+    elif "coder" in lowermodel:
+        formatted_messages = template_alpaca(self, messages)
     else:
         formatted_messages = template_llama(self, messages)
     LOGGER.debug(formatted_messages)
@@ -634,8 +630,8 @@ def response_qwen(self, response):
     curr_line = ""
     fn_call = None
     for text in response:
-
         full_text += text
+
         if text == "\n":
             if curr_line.startswith("✿FUNCTION✿:"):
                 fn_call = { 'name': curr_line[11:].strip(), 'id': str(uuid.uuid4()), 'arguments': None }
@@ -644,6 +640,10 @@
                 yield delta_tool_call(id, fn_call['id'], fn_call['name'], fn_call['arguments'])
             lines.append(curr_line)
             curr_line = ""
+        elif text == '✿' and fn_call is None:
+            fn_call = {}
+            lines.append(curr_line)
+            curr_line = text
         else:
             curr_line += text
             if curr_line.startswith("✿"):
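
The new elif branch above makes response_qwen start buffering a tool call as soon as a bare '✿' token arrives, rather than only reacting once a whole line has been terminated by a newline; that matters when the marker is streamed as its own token. For context, a self-contained sketch of the line-oriented format being parsed, with an invented completion string (the ✿ARGS✿ line is assumed from Qwen's function-calling convention; only ✿FUNCTION✿, ✿RESULT✿ and ✿RETURN✿ appear in this diff):

import uuid

# Invented example of what a Qwen-style tool call looks like in the raw completion text.
sample = '✿FUNCTION✿: r2cmd\n✿ARGS✿: {"command": "afl"}\n'

def parse_calls(tokens):
    # Simplified, non-streaming cousin of response_qwen: collect marker lines into a call dict.
    call, line = None, ""
    for token in tokens:
        line += token
        if token != "\n":
            continue
        text = line.strip()
        if text.startswith("✿FUNCTION✿:"):
            call = {"name": text[len("✿FUNCTION✿:"):].strip(), "id": str(uuid.uuid4())}
        elif text.startswith("✿ARGS✿:") and call is not None:
            call["arguments"] = text[len("✿ARGS✿:"):].strip()
            yield call
            call = None
        line = ""

for call in parse_calls(iter(sample)):  # iterating the string yields one character per "token"
    print(call["name"], call["arguments"])
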
@@ -987,70 +987,24 @@ def template_llama(self,messages):
         formatted_messages += f"{content}</s><s>[INST]"
     return formatted_messages
 
-def create_chat_completion(self, **kwargs):
+
+def get_completion_opts(self, **kwargs):
     messages = kwargs.pop('messages')
     tools = kwargs.pop('tools')
+    lowermodel = self.model.lower()
     prompt = messages_to_prompt(self, messages, tools)
-    return response_to_message(self, create_completion(self.llama_instance, prompt=prompt, **kwargs))
-
-
-def create_completion(
-    self,
-    prompt: Union[str, List[int]],
-    suffix: Optional[str] = None,
-    max_tokens: Optional[int] = 16,
-    temperature: float = 0.8,
-    top_p: float = 0.95,
-    min_p: float = 0.05,
-    typical_p: float = 1.0,
-    logprobs: Optional[int] = None,
-    echo: bool = False,
-    stop: Optional[Union[str, List[str]]] = [],
-    frequency_penalty: float = 0.0,
-    presence_penalty: float = 0.0,
-    repeat_penalty: float = 1.0,
-    top_k: int = 40,
-    stream: bool = False,
-    seed: Optional[int] = None,
-    tfs_z: float = 1.0,
-    mirostat_mode: int = 0,
-    mirostat_tau: float = 5.0,
-    mirostat_eta: float = 0.1,
-    model: Optional[str] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    grammar: Optional[LlamaGrammar] = None,
-    logit_bias: Optional[Dict[str, float]] = None,
-) -> Union[
-    Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
-]:
-
-    prompt_tokens = self.tokenize(
-        prompt.encode("utf-8"),
-        add_bos=False,
-        special=True
-    )
-
-    for token in self.generate(
-        prompt_tokens,
-        top_k=top_k,
-        top_p=top_p,
-        min_p=min_p,
-        typical_p=typical_p,
-        temp=temperature,
-        tfs_z=tfs_z,
-        mirostat_mode=mirostat_mode,
-        mirostat_tau=mirostat_tau,
-        mirostat_eta=mirostat_eta,
-        frequency_penalty=frequency_penalty,
-        presence_penalty=presence_penalty,
-        repeat_penalty=repeat_penalty,
-        stopping_criteria=stopping_criteria,
-        logits_processor=logits_processor,
-        grammar=grammar,
-    ):
-
-        if llama_cpp.llama_token_is_eog(self._model.model, token):
-            break
-        text = self.detokenize([token], special=True).decode("utf-8")
-        yield text
+    if tools and len(tools) > 0:
+        if "qwen" in lowermodel:
+            kwargs['stop'] = ["✿RESULT✿:", "✿RETURN✿:"]
+
+    kwargs['prompt'] = prompt
+    return kwargs
+
+def create_chat_completion(self, **kwargs):
+    opts = get_completion_opts(self, **kwargs)
+    completion = create_completion(self.llama_instance, **opts)
+    return response_to_message(self, completion)
+
+def create_completion(self, **kwargs):
+    for item in self.create_completion(**kwargs):
+        yield item['choices'][0]['text']
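
Taken together, the replacement code splits the old monolithic path in two: get_completion_opts turns chat-style keyword arguments into plain completion arguments (injecting Qwen's ✿RESULT✿/✿RETURN✿ stop strings when tools are in play), create_chat_completion hands those to the llama instance, and the thin create_completion wrapper flattens streamed chunks down to text. A rough usage sketch with a stubbed model object and prompt builder, since the real llama_cpp instance and r2ai templates sit outside this diff:

from types import SimpleNamespace

# Stub standing in for the interpreter/model state that get_completion_opts reads.
fake_self = SimpleNamespace(model="Qwen2.5-Coder-7B-Instruct")

def stub_messages_to_prompt(self, messages, tools=None):
    # Stand-in for completion.py's template dispatch.
    return "\n".join(m["content"] for m in messages)

def build_completion_opts(self, **kwargs):
    # Mirrors the shape of the new get_completion_opts, with the prompt builder stubbed out.
    messages = kwargs.pop("messages")
    tools = kwargs.pop("tools")
    lowermodel = self.model.lower()
    kwargs["prompt"] = stub_messages_to_prompt(self, messages, tools)
    if tools and len(tools) > 0 and "qwen" in lowermodel:
        kwargs["stop"] = ["✿RESULT✿:", "✿RETURN✿:"]
    return kwargs

opts = build_completion_opts(
    fake_self,
    messages=[{"role": "user", "content": "list the functions in this binary"}],
    tools=[{"type": "function", "function": {"name": "r2cmd"}}],  # hypothetical schema
    max_tokens=1024,
    stream=True,
)
print(sorted(opts))   # ['max_tokens', 'prompt', 'stop', 'stream']
print(opts["stop"])   # ['✿RESULT✿:', '✿RETURN✿:']
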
2 changes: 1 addition & 1 deletion r2ai/interpreter.py
@@ -382,7 +382,7 @@ def respond(self):
                 response = auto.chat(self)
             else:
                 self.llama_instance = new_get_hf_llm(self, self.model, int(self.env["llm.window"]))
-                response = auto.chat(self, llama_instance=self.llama_instance)
+                response = auto.chat(self)
             return

         elif self.model.startswith("kobaldcpp"):