-
Notifications
You must be signed in to change notification settings - Fork 0
/
eda_pac_mia.py
58 lines (53 loc) · 3.25 KB
/
eda_pac_mia.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import argparse
from utils import *
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
def compute_eda_pac(args):
    """Compute EDA-PAC membership-inference scores with a deduped Pythia model.

    Loads the ``EleutherAI/pythia-{model_size}-deduped`` checkpoint at training
    step 143000, then for every (dataset_idx, min_len, dataset_name)
    combination collects EDA-PAC scores for the "member" and "nonmember"
    splits, pickles the raw scores and sample indices, and writes a per-dataset
    results CSV via ``results_caculate_and_draw``.

    Args:
        args: argparse.Namespace carrying at least ``model_size``, ``cuda``,
            ``dataset_idx``, ``save_dir``, ``relative``, ``truncated`` and
            ``same_length`` (plus whatever the ``utils`` helpers read).
            NOTE(review): ``args.dataset_idx`` and ``args.min_len`` are
            mutated in place as the loops advance — callers should not rely
            on their values afterwards.

    Side effects:
        Creates directories and writes ``.pkl`` / ``.csv`` files under
        ``{args.save_dir}_{args.dataset_idx}/...``; downloads model and
        tokenizer weights into a local cache directory on first use.
    """
    dataset_names, length_list = get_dataset_list(args)

    # Fix 1: the original called GPTNeoXForCausalLM, which is never imported
    # in this module; AutoModelForCausalLM (imported at the top of the file)
    # dispatches to the same architecture for Pythia checkpoints.
    # Fix 2: the original passed device_map=args.cuda to from_pretrained AND
    # later called model.to(device); accelerate forbids .to() on a
    # device_map-dispatched model, so placement is now done once, explicitly.
    model = AutoModelForCausalLM.from_pretrained(
        f"EleutherAI/pythia-{args.model_size}-deduped",
        revision="step143000",
        cache_dir=f"./pythia-{args.model_size}-deduped/step143000",
        torch_dtype=torch.bfloat16,
    )
    # NOTE(review): to_bettertransformer() is deprecated in recent
    # transformers releases (SDPA is the default attention path) — kept to
    # preserve behavior on the pinned environment; confirm before upgrading.
    model = model.to_bettertransformer()
    device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        f"EleutherAI/pythia-{args.model_size}-deduped",
        revision="step143000",
        cache_dir=f"./pythia-{args.model_size}-deduped/step143000",
    )
    # Pythia tokenizers ship without a pad token; reuse EOS so batching works.
    tokenizer.pad_token = tokenizer.eos_token

    for dataset_idx in range(args.dataset_idx, 3):
        args.dataset_idx = dataset_idx
        for min_len in length_list:
            args.min_len = min_len
            for dataset_name in dataset_names:
                # All artifacts for this combination share one directory and
                # filename prefix; hoisting them removes five copies of the
                # same f-string while producing byte-identical paths.
                out_dir = f"{args.save_dir}_{args.dataset_idx}/{dataset_name}/{args.relative}/{args.truncated}"
                prefix = f"{out_dir}/{args.min_len}_{args.model_size}"

                # Resume support: skip combinations whose results already exist.
                if os.path.exists(f"{prefix}_eda_pac_dict.pkl"):
                    print(f"{dataset_idx} {dataset_name} {args.min_len} {args.model_size} finished")
                    continue

                df = pd.DataFrame()
                dataset = obtain_dataset(dataset_name, args)
                eda_pac_dict = {dataset_name: {"member": [], "nonmember": []}}
                idx_dict = {dataset_name: {"member": [], "nonmember": []}}
                for split in ["member", "nonmember"]:
                    eda_pac_list, idx_list = eda_pac_collection(
                        model, tokenizer, dataset[split], dataset_name, args,
                        min_len=args.min_len,
                    )
                    eda_pac_dict[dataset_name][split].extend(eda_pac_list)
                    idx_dict[dataset_name][split].extend(idx_list)

                # makedirs with exist_ok=True creates all intermediate
                # directories, so one call replaces the original's two.
                os.makedirs(out_dir, exist_ok=True)
                # Fix 3: the original leaked file handles via
                # pickle.dump(obj, open(..., "wb")); close deterministically.
                with open(f"{prefix}_idx_list.pkl", "wb") as fh:
                    pickle.dump(idx_dict, fh)
                with open(f"{prefix}_eda_pac_dict.pkl", "wb") as fh:
                    pickle.dump(eda_pac_dict, fh)

                df = results_caculate_and_draw(dataset_name, args, df, method_list=["eda_pac"])
                suffix = "same_length" if args.same_length else "all_length"
                df.to_csv(f"{prefix}_{suffix}.csv")