# This file is used to generate an uncertainty score for each question
from utils import *
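# the wildcard import above is expected to provide API_KEY, NO_SOLUTION,
# set_random_seed, create_dataloader, create_input_prompt, GPT3_request,
# and answer_extraction, all of which are referenced below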
import time
import argparse
import numpy as np
import json
from scipy.stats import entropy


def main():
    args = arg_parser()
    print('*****************************')
    print(args)
    print('*****************************')
    print(f"API_KEY: {API_KEY}")

    set_random_seed(args.random_seed)
    dataloader = create_dataloader(args)
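    # note: args.dataset_size is assumed to be set by create_dataloader above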
    if args.dataset_size > 1000:
        # only annotate 1000 randomly chosen questions; randomness is decided by the seed
        dataloader = dataloader[:1000]
    print(f"Dataloader size: {len(dataloader)}")

    if args.qes_limit == 0:
        args.qes_limit = len(dataloader)

    start = time.time()
    result = create_uncertainty(args, dataloader)
    end = time.time()
    print('Total Execution Time: ', end - start, " seconds")

    # output the results
    path = f"{args.output_dir}/uncertainty_result_{args.dataset}_k{args.num_trails}.txt"
    with open(path, 'w') as f:
        try:
            f.write(json.dumps(result, indent=4))
        except TypeError:
            # fall back to a plain-text dump if the records are not JSON-serializable
            for item in result:
                try:
                    if args.dataset in ("gsm8k", "asdiv", "svamp", "singleeq", "addsub", "multiarith"):
                        f.write(f"{item}, uncertainty: {len(item[-1])}, variance: {item[1]}\n")
                    else:
                        f.write(f"{item}, uncertainty: {len(item[-1])}\n")
                except (IndexError, KeyError, TypeError):
                    pass


def generate_uncertainty_qes(args, question):
    if args.method == "few_shot_cot":
        given_prompt = create_input_prompt(args, True)

    if args.dataset in ("gsm8k", "asdiv", "svamp", "singleeq", "addsub", "multiarith"):
        # 0.0 is a placeholder; the variance is filled in after all trials
        uncertainty_record = {'dataset_idx': question['question_idx'], 'variance': 0.0,
                              'entropy': 0.0, 'occurrence': {}}
    elif args.dataset == "strategyqa":
        uncertainty_record = {'dataset_idx': question['question_idx'], 'entropy': 0.0,
                              'occurrence': {"yes": 0, "no": 0}}
    else:
        uncertainty_record = {'dataset_idx': question['question_idx'], 'entropy': 0.0,
                              'occurrence': {}}
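    # 'occurrence' maps each distinct predicted answer to the number of trials
    # that produced it; all three uncertainty metrics below are derived from it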

    for _ in range(args.num_trails):
        # construct the first-stage prompt; for few-shot this includes the
        # exemplar prompt, for zero-shot just the question plus the trigger
        if args.method == "few_shot_cot":
            prompt = given_prompt + "Q: " + question['question'] + "\nA: Let's think step by step."
        elif args.method == "zero_shot_cot":
            prompt = "Q: " + question['question'] + "\nA: Let's think step by step."
        prompt_list = [prompt]

        # for zero-shot this is the first-stage rationale; for few-shot it is
        # already the final output
        responses = GPT3_request(model=args.model, input_prompt=prompt_list,
                                 max_tokens=args.max_length_cot,
                                 time_interval=args.api_time_interval,
                                 temperature=args.temperature,
                                 stop=['Question:', "Q:"])

        if args.method == "zero_shot_cot":
            # construct the second-stage prompt to elicit a single arabic-numeral answer
            prompt_list[0] += responses['choices'][0]['text'] + args.direct_answer_trigger
            # second-stage request: rationale -> final answer
            responses = GPT3_request(model=args.model, input_prompt=prompt_list,
                                     max_tokens=args.max_length_cot,
                                     time_interval=args.api_time_interval,
                                     temperature=args.temperature,
                                     stop='.')

        # extract the predicted answer and record its occurrence;
        # empty predictions are counted under NO_SOLUTION
        pred_ans = answer_extraction(args, responses)
        key = pred_ans if pred_ans != "" else NO_SOLUTION
        uncertainty_record['occurrence'][key] = uncertainty_record['occurrence'].get(key, 0) + 1

    # variance of the sampled answers (only for datasets with numerical answers)
    if args.dataset in ("gsm8k", "asdiv", "svamp", "singleeq", "addsub", "multiarith"):
        ans_list = []
        for ans, occurs in uncertainty_record['occurrence'].items():
            ans_list.extend([float(ans)] * int(occurs))
        uncertainty_record['variance'] = np.var(ans_list)

    # entropy of the answer distribution (all datasets); scipy normalizes
    # the raw counts to probabilities
    frequency_list = list(uncertainty_record['occurrence'].values())
    uncertainty_record['entropy'] = entropy(frequency_list)

    # disagreement = number of distinct answers observed (all datasets)
    uncertainty_record['disagreement'] = len(uncertainty_record['occurrence'])
    return uncertainty_record
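
# Worked example for the three metrics above (illustrative only): five trials
# yielding occurrence = {'12': 3, '15': 2} give
#   variance     = np.var([12, 12, 12, 15, 15])  -> 2.16
#   entropy      = entropy([3, 2])               -> ~0.673 nats
#   disagreement = len(occurrence)               -> 2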


# return a list sorted by uncertainty, from high to low
def create_uncertainty(args, questions):
    result = []
    count = 0
    for qes in questions:
        if count == args.qes_limit:
            break
        uncertainty_record = generate_uncertainty_qes(args, qes)
        result.append(uncertainty_record)
        count += 1

    if args.sort_by == "disagreement":
        if args.dataset == "strategyqa":
            try:
                # for yes/no questions, a smaller margin between the "yes" and
                # "no" counts means higher disagreement, so sort ascending by |yes - no|
                result.sort(key=lambda x: abs(x['occurrence']['yes'] - x['occurrence']['no']))
            except KeyError:
                # fall back to sorting by the number of distinct answers
                result.sort(key=lambda x: -len(x['occurrence']))
        else:
            result.sort(key=lambda x: -len(x['occurrence']))
    elif args.sort_by == "variance" and args.dataset in ("gsm8k", "asdiv", "svamp", "singleeq", "addsub", "multiarith"):
        # sort by variance
        result.sort(key=lambda x: -x['variance'])
    elif args.sort_by == "entropy":
        result.sort(key=lambda x: -x['entropy'])
    return result
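
# Each record in the returned list looks like, e.g.:
#   {'dataset_idx': 17, 'variance': 2.16, 'entropy': 0.673,
#    'occurrence': {'12': 3, '15': 2}, 'disagreement': 2}
# ('variance' only appears for the numeric datasets)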


def arg_parser():
    parser = argparse.ArgumentParser(description="Uncertainty_Generation")
    parser.add_argument("--random_seed", type=int, default=1, help="random seed")
    parser.add_argument(
        "--dataset", type=str, default="gsm8k",
        choices=["gsm8k", "svamp", "aqua", "csqa", "last_letters", "strategyqa", "asdiv", "singleeq", "addsub", "multiarith", "time_zone"],
        help="dataset to run inference on"
    )
    parser.add_argument(
        "--prompt_path", type=str, default="./basic_cot_prompts/math_word_problems", help="prompts to use"
    )
    parser.add_argument(
        "--model", type=str, default="code-davinci-002", choices=["text-davinci-002", "code-davinci-002"], help="model used for decoding"
    )
    parser.add_argument(
        "--method", type=str, default="few_shot_cot", choices=["zero_shot_cot", "few_shot_cot"], help="prompting method"
    )
    parser.add_argument(
        "--output_dir", type=str, default="./", help="output directory"
    )
    parser.add_argument(
        "--max_length_cot", type=int, default=256, help="maximum number of tokens the model may generate for reasoning extraction"
    )
    parser.add_argument(
        "--qes_limit", type=int, default=0, help="limit on the number of questions to process; if 0, all samples in the dataset are used"
    )
    parser.add_argument(
        "--api_time_interval", type=float, default=1.0, help="seconds to sleep between API requests"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="sampling temperature"
    )
    parser.add_argument(
        "--num_trails", type=int, default=5, help="number of trials to run for each question"
    )
    parser.add_argument(
        "--sort_by", type=str, default='disagreement', choices=['disagreement', 'variance', 'entropy'], help="sort the final result by the given metric"
    )
    parser.add_argument(
        "--concat_length", type=int, default=2, help="used for the last_letters task; length of the last-letter concatenation"
    )
    args = parser.parse_args()

    # Fill in the dataset path and the direct-answer trigger
    if args.dataset == "gsm8k":
        args.dataset_path = "./dataset/GSM8K/train.jsonl"  # train data path
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset in ("svamp", "asdiv", "singleeq", "addsub", "multiarith"):
        # these numeric datasets share the same trigger; their dataset_path is
        # expected to be configured elsewhere
        args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif args.dataset == "aqua":
        args.dataset_path = "./dataset/AQuA/train.json"  # train data path
        args.direct_answer_trigger = "\nThe answer is"
    elif args.dataset == "csqa":
        args.dataset_path = "./dataset/CSQA/train_rand_split.jsonl"  # train data path
        args.direct_answer_trigger = "\nSo the answer is"
    elif args.dataset == "strategyqa":
        args.dataset_path = "./dataset/strategyQA/train.json"  # train data path
        args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is"
    elif args.dataset == "last_letters":
        args.dataset_path = "./dataset/last_letters/last_letters_train2.json"  # train data path
        args.direct_answer_trigger = "\nTherefore, the answer is"
    elif args.dataset == 'time_zone':
        args.dataset_path = "./dataset/timezone_convert/timezone_convertion_train.json"
        # assumed trigger: the original left this unset, which would crash below
        args.direct_answer_trigger = "\nTherefore, the answer is"
    else:
        raise ValueError("dataset is not properly defined ...")

    # "Therefore, the answer ..." -> "The answer ..."
    trigger = args.direct_answer_trigger.replace("\nTherefore, ", "")
    args.direct_answer_trigger_for_zeroshot = trigger[0].upper() + trigger[1:]
    args.direct_answer_trigger_for_zeroshot_cot = args.direct_answer_trigger
    args.direct_answer_trigger_for_fewshot = "The answer is"
    args.cot_trigger = "Let's think step by step."
    return args


if __name__ == "__main__":
    main()
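
# Example invocation (a sketch; assumes the dataset and prompt files exist at
# the default paths above):
#   python uncertainty_estimation.py --dataset gsm8k --method few_shot_cot \
#       --model code-davinci-002 --num_trails 10 --sort_by disagreement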