-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualize_attention.py
49 lines (40 loc) · 1.48 KB
/
visualize_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import numpy as np
import time
import sys
import os
import torch
from bertviz import model_view
def show_model_view(attention, tokens, hide_delimiter_attn=False, display_mode="dark"):
    """Render the bertviz model view for the given attention weights.

    Args:
        attention: per-layer attention tensors, each indexed as
            [batch, head, query_pos, key_pos] (batch 0 is visualized).
        tokens: token labels aligned with the attention positions.
        hide_delimiter_attn: if True, zero out all attention to and from
            "[SEP]" / "[CLS]" delimiter tokens before rendering.
        display_mode: bertviz display theme (e.g. "dark").
    """
    if hide_delimiter_attn:
        # Positions of delimiter tokens whose attention should be masked.
        delimiter_positions = [
            pos for pos, tok in enumerate(tokens) if tok in ("[SEP]", "[CLS]")
        ]
        for layer_attn in delimiter_positions and attention or attention:
            for pos in delimiter_positions:
                layer_attn[0, :, pos, :] = 0  # attention *from* the delimiter
                layer_attn[0, :, :, pos] = 0  # attention *to* the delimiter
    model_view(attention, tokens, display_mode=display_mode)
def visualize(feat_path, ind_path, label_path, attention_path):
    """Load saved arrays and attention weights, then show the attention map.

    Args:
        feat_path: path to a .npy file of features (loaded but currently
            unused by the visualization — kept so a missing file still
            fails loudly here).
        ind_path: path to a .npy file of per-token indices.
        label_path: path to a .npy file of per-token labels, aligned
            with ``ind_path``.
        attention_path: path to a torch-saved attention structure as
            expected by ``show_model_view``.
    """
    feats = np.load(feat_path)  # NOTE: not used below; retained for validation
    inds = np.load(ind_path)
    labels = np.load(label_path)
    # map_location keeps this usable on CPU-only machines even when the
    # attention tensors were saved from a GPU run.
    attention = torch.load(attention_path, map_location="cpu")
    # Token label "<label>_<index>" for each aligned (label, index) pair.
    tokens = ["%s_%s" % (label, ind) for label, ind in zip(labels, inds)]
    show_model_view(attention, tokens)
if __name__ == '__main__':
    # Command-line entry point: collect the four input paths and visualize.
    parser = argparse.ArgumentParser(description='Visualize attention maps')
    parser.add_argument('--feat_path', type=str)
    parser.add_argument('--ind_path', type=str)
    parser.add_argument('--label_path', type=str)
    parser.add_argument('--attention_path', type=str)
    args = parser.parse_args()
    visualize(args.feat_path, args.ind_path, args.label_path, args.attention_path)