-
Notifications
You must be signed in to change notification settings - Fork 2
/
pred.py
73 lines (56 loc) · 2.34 KB
/
pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pickle
import torch
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from SiameseModel import ContrastiveClassifier
import numpy as np
from glob import glob
from argparse import ArgumentParser
class AICTPairWithPreGenEmbDataset(torch.utils.data.Dataset):
    """Dataset with one item per ``*.pkl`` file of pre-generated call-pair embeddings.

    Each pickle file is expected to hold an iterable of 4-tuples
    ``(caller_sig, caller_emb, callee_sig, callee_emb)``; only the two
    embeddings are used, the signatures are ignored.
    """

    def __init__(self, dataset_path):
        """Remember *dataset_path* and index every ``*.pkl`` file directly inside it."""
        self.dataset_path = dataset_path
        print('Loading dataset...')
        self.emb_files = []
        self.load_data()

    def __getitem__(self, idx):  # per callsite
        """Load the pickle file at position *idx*.

        Returns:
            A tuple ``(file_path, caller_embs, callee_embs)`` where the two
            arrays are built with ``np.array`` from the per-pair embeddings,
            in the order the pairs appear in the file.
        """
        emb_file = self.emb_files[idx]
        # NOTE: pickle.load can execute arbitrary code while deserialising;
        # only point this dataset at embedding files from a trusted source.
        with open(emb_file, 'rb') as f:
            call_pairs = pickle.load(f)

        caller_embs = []
        callee_embs = []
        for _caller_sig, caller_emb, _callee_sig, callee_emb in tqdm(call_pairs):
            caller_embs.append(caller_emb)
            callee_embs.append(callee_emb)

        # Echo the file being served so downstream prediction output can be
        # matched to its source file.
        print(emb_file)
        return emb_file, np.array(caller_embs), np.array(callee_embs)

    def __len__(self):
        """Return the number of indexed embedding files."""
        return len(self.emb_files)

    def load_data(self):
        """Append the paths of all ``*.pkl`` files in ``dataset_path`` to ``emb_files``.

        The paths are sorted because ``glob`` returns them in a
        filesystem-dependent order; sorting keeps the index -> file mapping
        stable across runs and machines.
        """
        self.emb_files.extend(sorted(glob(f'{self.dataset_path}/*.pkl')))
# Choose the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU. The chosen device is echoed to stdout.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(dev)
if __name__ == '__main__':
    # Script entry point: load a trained ContrastiveClassifier checkpoint and
    # print its predictions for every embedding file found in --emb_dir.
    parser = ArgumentParser()
    parser.add_argument('-i','--emb_dir', type=str, help='embeddings dir', nargs='?', default='./aict-embeddings')
    parser.add_argument('--model', type=str, help='siamese network model', nargs='?', default='./model_bce_with_pregen_emb_2.pth')
    args = parser.parse_args()

    # NOTE(review): the constructor arguments presumably have to match the
    # architecture the checkpoint was trained with — confirm in SiameseModel.
    model = ContrastiveClassifier(3, 100, 256, 128, 1, 256, 1).to(dev)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU machine still loads when only the CPU is
    # available (without it torch.load raises in that situation).
    checkpoint = torch.load(args.model, map_location=dev)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # batch_size=1: every batch corresponds to exactly one embedding file.
    aict_loader = DataLoader(AICTPairWithPreGenEmbDataset(args.emb_dir), batch_size=1, num_workers=0, shuffle=True)

    with torch.no_grad():
        # tqdm wraps the loader itself (not enumerate) so the bar knows the
        # total number of batches.
        for i, (_binary_name, caller_embs, callee_embs) in enumerate(tqdm(aict_loader)):
            # squeeze() drops the size-1 batch dimension added by the loader.
            # NOTE(review): it also drops any other size-1 dimension, e.g.
            # the pair dimension when a file holds a single call pair —
            # confirm the model copes with that input shape.
            caller_embs = torch.squeeze(caller_embs.to(dev))
            callee_embs = torch.squeeze(callee_embs.to(dev))
            preds = model(caller_embs, callee_embs)
            print(f'Callsite {i}, preds:', preds.cpu().numpy())