# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:light
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---
# # Running Llama Guard inference
#
# This notebook is intended to showcase how to run Llama Guard inference on a sample prompt for testing.
# +
# # !pip install --upgrade huggingface_hub

# +
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
from typing import Any, Dict, List, Optional, Tuple
from enum import Enum

import torch
from tqdm import tqdm
# -
# # Defining the main functions
#
# The AgentType enum defines which type of inference Llama Guard should run: evaluating a User prompt or an Agent response.
#
# The llm_eval function loads the Llama Guard model from Hugging Face, then iterates over the prompts and generates a classification result for each conversation.
# +
class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"


def llm_eval(prompts: List[Dict[str, Any]],
             model_id: str = "meta-llama/Meta-Llama-Guard-2-8B",
             llama_guard_version: str = LlamaGuardVersion.LLAMA_GUARD_2.name,
             load_in_8bit: bool = True,
             load_in_4bit: bool = False,
             logprobs: bool = False) -> Tuple[List[str], Optional[List[List[Tuple[int, float]]]]]:
    """
    Runs Llama Guard inference with HF transformers. Works with Llama Guard 1 or 2.

    This function loads Llama Guard from Hugging Face or a local path and
    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.

    Parameters
    ----------
    prompts : List[Dict[str, Any]]
        List of dicts containing the conversations to evaluate. Each dict holds a "prompt" key with the list of
        messages that make up a conversation and an "agent_type" key with the role to evaluate.
    model_id : str
        The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
        or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/Meta-Llama-Guard-2-8B'.
    llama_guard_version : str
        The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_2.
    load_in_8bit : bool
        Defines if the model should be loaded in 8 bit. Uses BitsAndBytes. Default True.
    load_in_4bit : bool
        Defines if the model should be loaded in 4 bit. Uses BitsAndBytes and the nf4 method. Default False.
    logprobs : bool
        Defines if the log probabilities of the output tokens should be returned as well. Default False.
    """
    try:
        llama_guard_version = LlamaGuardVersion[llama_guard_version]
    except KeyError as e:
        raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    torch_dtype = torch.bfloat16
    # if load_in_4bit:
    #     torch_dtype = torch.bfloat16

    bnb_config = BitsAndBytesConfig(
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype
    )

    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

    results: List[str] = []
    if logprobs:
        result_logprobs: List[List[Tuple[int, float]]] = []

    total_length = len(prompts)
    progress_bar = tqdm(colour="blue", desc="Prompts", total=total_length, dynamic_ncols=True)
    for prompt in prompts:
        formatted_prompt = build_default_prompt(
            prompt["agent_type"],
            create_conversation(prompt["prompt"]),
            llama_guard_version)

        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
        prompt_len = inputs["input_ids"].shape[-1]
        output = model.generate(**inputs, max_new_tokens=10, pad_token_id=0, return_dict_in_generate=True, output_scores=logprobs)

        if logprobs:
            # Per-token log probabilities for the generated tokens.
            transition_scores = model.compute_transition_scores(
                output.sequences, output.scores, normalize_logits=True)

        # Keep only the newly generated tokens, dropping the echoed prompt.
        generated_tokens = output.sequences[:, prompt_len:]

        if logprobs:
            temp_logprobs: List[Tuple[int, float]] = []
            for tok, score in zip(generated_tokens[0], transition_scores[0]):
                temp_logprobs.append((tok.cpu().numpy(), score.cpu().numpy()))

            result_logprobs.append(temp_logprobs)
            prompt["logprobs"] = temp_logprobs

        result = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        prompt["result"] = result
        results.append(result)
        progress_bar.update(1)

    progress_bar.close()
    return (results, result_logprobs if logprobs else None)
# -
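
# # Interpreting the output
#
# Llama Guard answers with "safe", or with "unsafe" followed by the violated category codes (for example "S1") on the next line.
# The helper below is a minimal sketch that is not part of the original script: `parse_guard_result` is a hypothetical name,
# and the comma-separated category list is an assumption about the model's output format.
# If llm_eval is called with logprobs=True, each returned (token_id, score) pair holds a log probability; math.exp(score) converts it to a probability.

# +
def parse_guard_result(result: str) -> Tuple[bool, List[str]]:
    # First non-empty line is the verdict: "safe" or "unsafe".
    lines = [line.strip() for line in result.strip().split("\n") if line.strip()]
    is_safe = bool(lines) and lines[0].lower() == "safe"
    # When unsafe, the next line lists the violated category codes (assumed comma-separated).
    categories = [] if is_safe or len(lines) < 2 else [c.strip() for c in lines[1].split(",")]
    return is_safe, categories

# Example: parse_guard_result("unsafe\nS1") -> (False, ["S1"])
# -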

def main():
    prompts: List[Dict[str, Any]] = [
        {
            "prompt": ["<Sample user prompt>"],
            "agent_type": AgentType.USER
        },
        {
            "prompt": ["<Sample user prompt>", "<Sample agent response>"],
            "agent_type": AgentType.AGENT
        },
        {
            "prompt": ["<Sample user prompt>",
                       "<Sample agent response>",
                       "<Sample user reply>",
                       "<Sample agent response>"],
            "agent_type": AgentType.AGENT
        }
    ]

    results = llm_eval(prompts, load_in_8bit=False, load_in_4bit=True)

    # llm_eval returns (results, logprobs); index 0 holds the decoded result strings.
    for i, prompt in enumerate(prompts):
        print(prompt["prompt"])
        print(f"> {results[0][i]}")
        print("\n==================================\n")


# Guard so this script can be imported into another notebook without running the main function.
if __name__ == '__main__' and '__file__' not in globals():
    # from huggingface_hub import login
    # login()
    main()
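
# # Importing from another notebook
#
# Because main() only runs under the guard above (executed as __main__ with no __file__ defined, i.e. inside a notebook),
# this module can be imported elsewhere without evaluating the sample prompts. A minimal sketch, assuming the file is
# importable as `inference` (the module name depends on where the file lives and is an assumption here):
#
# ```python
# from inference import llm_eval, AgentType
#
# results, _ = llm_eval(
#     [{"prompt": ["<Sample user prompt>"], "agent_type": AgentType.USER}],
#     load_in_8bit=False,
#     load_in_4bit=True,
# )
# print(results[0])
# ```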