--[[
word2vec.lua
Word2vec wrapper for Torch, adapted from:
https://github.com/rotmanmi/word2vec.torch
--]]
require 'paths'                 -- file-existence checks
require 'hdf5'                  -- reads the embedding matrix from vectors.h5
local cjson = require 'cjson'   -- parses rows.json (the vocabulary list)

torch.setdefaulttensortype('torch.FloatTensor')

local Word2vec = torch.class('Word2vec')
-- dir: directory holding vectors.h5 (and rows.json); classes: label index -> word
function Word2vec:__init(dir, classes)
  self.classes = classes
  self.w2v = {}
  local f = dir .. '/vectors.h5'
  if not paths.filep(f .. '.t7') then
    -- first run: convert the HDF5 vectors to a cached .t7 file
    print(string.format('Converting word vectors from %s', f))
    self.w2v = self:_bintot7(dir)
  else
    -- later runs: load the cached .t7 directly
    print(string.format('Loading cached word vectors from %s', f .. '.t7'))
    self.w2v = torch.load(f .. '.t7')
  end
  print('Done reading word2vec data.')
end
-- Read a whole JSON file and decode it into a Lua table.
function Word2vec:_read_json(path)
  local file = assert(io.open(path, 'r'))
  local text = file:read('*a')  -- '*a' reads the whole file, not just one line
  file:close()
  return cjson.decode(text)
end
-- Convert vectors.h5 + rows.json into a .t7 cache holding the unit-normalized
-- embedding matrix and the word<->index lookup tables.
function Word2vec:_bintot7(dir)
  local f = dir .. '/vectors.h5'
  local h5_file = hdf5.open(f, 'r')
  local vecs_size = h5_file:read('/vectors'):dataspaceSize()
  local words = vecs_size[1]
  local dims = vecs_size[2]
  local M = h5_file:read('/vectors'):partial({1, words}, {1, dims})
  h5_file:close()
  local rows = self:_read_json(dir .. '/rows.json')
  local word2vec = {}
  word2vec.M = torch.FloatTensor(M:size())
  word2vec.w2vvocab = {}  -- word  -> row index
  word2vec.v2wvocab = {}  -- index -> word
  for i = 1, words do
    local w = rows[i]
    word2vec.v2wvocab[i] = w
    word2vec.w2vvocab[w] = i
    -- normalize each vector to unit norm so dot products are cosines
    local n = M[i]:norm()
    word2vec.M[i] = M[i] / n
  end
  print('Writing t7 file for future use.')
  torch.save(f .. '.t7', word2vec)
  return word2vec
end
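
--[[
Expected input layout (inferred from _bintot7 above; file names are the ones
this code reads, the concrete contents shown are only illustrative):

  <dir>/vectors.h5   HDF5 file with a dataset '/vectors' of size words x dims
  <dir>/rows.json    JSON array of the words, one per row of the matrix,
                     e.g. ["the", "of", "cat", ...]
--]]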
-- Embedding for a class label (label index -> class word -> vector).
function Word2vec:getVector(label)
  local ind = self.w2v.w2vvocab[self.classes[label]]
  return self.w2v.M[ind]
end

-- Embedding for a raw vocabulary word.
function Word2vec:getWordVector(label)
  local ind = self.w2v.w2vvocab[label]
  return self.w2v.M[ind]
end
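
--[[
Minimal usage sketch (the directory name and class list here are
hypothetical, not part of this file):

  local classes = {'cat', 'dog'}         -- label index -> vocabulary word
  local w2v = Word2vec('data', classes)  -- expects data/vectors.h5, data/rows.json
  local v1 = w2v:getVector(1)            -- unit-norm vector for classes[1]
  local v2 = w2v:getWordVector('dog')    -- unit-norm vector for a raw word
--]]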
-- Rank each positive example's gold class among the whole vocabulary by
-- cosine similarity. predictions is (els x dims); rows interleave one
-- positive with neg_samples negatives; labels[s] == 1 marks positives and
-- classes[s] is the gold label index for row s.
-- Returns top-k accuracy (%), mean cosine with the gold vector, median rank.
function Word2vec:eval_ranking(predictions, labels, classes, k, neg_samples)
  local els = predictions:size(1)
  -- normalize rows so a plain matrix product yields cosine similarity
  local p_norm = predictions:norm(2, 2)
  predictions:cdiv(p_norm:expandAs(predictions))
  -- cosine of every prediction against every vocabulary vector
  local cosine = predictions * self.w2v.M:transpose(1, 2)
  local sim = 0
  -- one ranking entry per positive example
  local ranking = torch.Tensor(els / (neg_samples + 1))
  local topk = 0
  local tot = 0
  for s = 1, els do
    if labels[s] == 1 then
      local _, index = torch.sort(cosine:select(1, s), true)  -- descending
      local ind = self.w2v.w2vvocab[self.classes[classes[s]]]
      -- walk down the sorted vocabulary until the gold word is found
      local not_found = true
      local r = 1
      while not_found do
        if index[r] == ind then
          ranking[tot + 1] = r
          not_found = false
          if r <= k then topk = topk + 1 end
        end
        r = r + 1
      end
      tot = tot + 1
      sim = sim + cosine[s][ind]  -- accumulate cosine with the gold vector
      -- print(string.format('Gold: %s -- Predicted %s', self.classes[classes[s]], self.w2v.v2wvocab[index[1]]))
    end
  end
  local median = torch.median(ranking)[1]
  return topk * 100 / tot, sim / tot, median
end
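
--[[
eval_ranking sketch, continuing the usage example above (all data here is
made up; shapes follow the code: one positive row per neg_samples+1 rows):

  local dims = w2v.w2v.M:size(2)              -- embedding dimensionality
  local preds = torch.randn(4, dims):float()  -- 2 positives + 2 negatives
  local labels = torch.Tensor{1, 0, 1, 0}     -- 1 marks the positive rows
  local gold = torch.Tensor{1, 1, 2, 2}       -- gold label index per row
  local topk_acc, avg_cos, med_rank = w2v:eval_ranking(preds, labels, gold, 5, 1)
--]]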