Skip to content

Commit

Permalink
Merge pull request #25 from ZeekoZhu/patch-1
Browse files Browse the repository at this point in the history
add more models
  • Loading branch information
WongSaang authored Jul 5, 2023
2 parents 39fc67d + 0bfa563 commit df7d9b9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ name = "pypi"
[packages]
django = "==4.1.7"
gunicorn = "==20.1.0"
openai = "~=0.27.2"
openai = "~=0.27.8"
psycopg2 = "~=2.9.5"
python-dotenv = "~=0.21.1"
dj-database-url = "~=1.2.0"
djangorestframework = "~=3.14.0"
tiktoken = "~=0.3.2"
tiktoken = "~=0.4.0"
djangorestframework-simplejwt = "~=5.2.2"
mysqlclient = "~=2.1.1"
django-allauth = "~=0.52.0"
Expand Down
52 changes: 41 additions & 11 deletions chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,20 @@ def delete_all(self, request):
'gpt-4': {
'name': 'gpt-4',
'max_tokens': 8192,
'max_prompt_tokens': 6196,
'max_prompt_tokens': 6192,
'max_response_tokens': 2000
},
'gpt-3.5-turbo-16k': {
'name': 'gpt-3.5-turbo-16k',
'max_tokens': 16384,
'max_prompt_tokens': 12384,
'max_response_tokens': 4000
},
'gpt-4-32k': {
'name': 'gpt-4-32k',
'max_tokens': 32768,
'max_prompt_tokens': 24768,
'max_response_tokens': 8000
}
}

Expand Down Expand Up @@ -784,13 +796,19 @@ def num_tokens_from_text(text, model="gpt-3.5-turbo-0301"):
encoding = tiktoken.get_encoding("cl100k_base")

if model == "gpt-3.5-turbo":
print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_text(text, model="gpt-3.5-turbo-0301")
print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_text(text, model="gpt-3.5-turbo-0613")
elif model == "gpt-3.5-turbo-16k":
print("Warning: gpt-3.5-turbo-16k may change over time. Returning num tokens assuming gpt-3.5-turbo-16k-0613.")
return num_tokens_from_text(text, model="gpt-3.5-turbo-16k-0613")
elif model == "gpt-4":
print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
return num_tokens_from_text(text, model="gpt-4-0314")
return num_tokens_from_text(text, model="gpt-4-0613")
elif model == "gpt-4-32k":
print("Warning: gpt-4-32k may change over time. Returning num tokens assuming gpt-4-32k-0613.")
return num_tokens_from_text(text, model="gpt-4-32k-0613")

if model not in ["gpt-3.5-turbo-0301", "gpt-4-0314"]:
if model not in ["gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-32k-0613"]:
raise NotImplementedError(f"""num_tokens_from_text() is not implemented for model {model}.""")

return len(encoding.encode(text))
Expand All @@ -804,17 +822,29 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo":
print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif model == "gpt-3.5-turbo-16k":
print("Warning: gpt-3.5-turbo-16k may change over time. Returning num tokens assuming gpt-3.5-turbo-16k-0613.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-16k-0613")
elif model == "gpt-4":
print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0613.")
return num_tokens_from_messages(messages, model="gpt-4-0613")
elif model == "gpt-4-32k":
print("Warning: gpt-4-32k may change over time. Returning num tokens assuming gpt-4-32k-0613.")
return num_tokens_from_messages(messages, model="gpt-4-32k-0613")
elif model == "gpt-3.5-turbo-0613":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
elif model == "gpt-3.5-turbo-16k-0613":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0613":
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-4-32k-0613":
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
num_tokens = 0
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mysqlclient==2.1.1
numexpr==2.8.4
numpy==1.24.3
oauthlib==3.2.2
openai==0.27.6
openai==0.27.8
openapi-schema-pydantic==1.2.4
packaging==23.1
platformdirs==3.5.0
Expand All @@ -61,7 +61,7 @@ soupsieve==2.4.1
SQLAlchemy==2.0.12
sqlparse==0.4.4
tenacity==8.2.2
tiktoken==0.3.3
tiktoken==0.4.0
tomli==2.0.1
tomlkit==0.11.8
tqdm==4.65.0
Expand Down

1 comment on commit df7d9b9

@gekowa
Copy link

@gekowa gekowa commented on df7d9b9 Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Please sign in to comment.