-
Notifications
You must be signed in to change notification settings - Fork 6
/
pred_vae.py
executable file
·67 lines (53 loc) · 2.15 KB
/
pred_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
'''
Created on Jan, 2017
@author: hugo
'''
from __future__ import absolute_import
import argparse
import numpy as np
from autoencoder.core.vae import VarAutoEncoder, load_vae_model
from autoencoder.core.deepae import DeepAutoEncoder
from autoencoder.preprocessing.preprocessing import load_corpus, doc2vec
from autoencoder.utils.op_utils import vecnorm, revdict
from autoencoder.utils.io_utils import dump_json, write_file
# def get_topics(vae, vocab, topn=10):
# topics = []
# weights = vae.encoder.get_weights()[0]
# for idx in range(vae.dim):
# token_idx = np.argsort(weights[:, idx])[::-1][:topn]
# topics.append([vocab[x] for x in token_idx])
# return topics
# def print_topics(topics):
# for i in range(len(topics)):
# str_topic = ' + '.join(['%s * %s' % (prob, token) for token, prob in topics[i]])
# print 'topic %s:' % i
# print str_topic
# print
def test(args):
    """Encode every document of a corpus with a trained VAE and save the codes.

    Reads the corpus at ``args.input`` (expected to provide ``'vocab'`` and
    ``'docs'`` entries), turns each document into a normalized vector, runs
    the model loaded from ``args.load_model`` over the stacked vectors and
    writes a ``{doc_key: code}`` mapping to ``args.output`` via ``dump_json``.

    Args:
        args: parsed command-line namespace; only ``input``, ``load_model``
            and ``output`` are read here.
    """
    corpus = load_corpus(args.input)
    vocab, docs = corpus['vocab'], corpus['docs']
    n_vocab = len(vocab)

    # Snapshot the keys into a real list. Entries are deleted from `docs`
    # inside the loop, which raises RuntimeError when iterating a live
    # Python 3 key view, and the same keys are needed afterwards to label
    # the predicted codes.
    doc_keys = list(docs.keys())
    X_docs = []
    for k in doc_keys:
        X_docs.append(vecnorm(doc2vec(docs[k], n_vocab), 'logmax1', 0))
        # Drop the raw document once it has been vectorized
        # (presumably to keep memory usage down on large corpora).
        del docs[k]
    X_docs = np.r_[X_docs]

    vae = load_vae_model(args.load_model)
    doc_codes = vae.predict(X_docs)
    # Codes are paired with keys positionally, so `doc_keys` must keep the
    # order in which the documents were vectorized above.
    dump_json(dict(zip(doc_keys, doc_codes.tolist())), args.output)
    # Parenthesized single-argument form prints identically on Python 2 and 3.
    print('Saved doc codes file to %s' % args.output)
# if args.save_topics:
# topics = get_topics(vae, revdict(vocab), topn=10)
# write_file(topics, args.save_topics)
# print 'Saved topics file to %s' % args.save_topics
def main():
    """Parse the command-line options and hand them to ``test``.

    All three options are mandatory string paths: the input corpus
    (``-i``), the output doc codes file (``-o``) and the trained model
    file (``-lm``).
    """
    cli = argparse.ArgumentParser()
    # (short flag, long flag, help text) for every required path option.
    required_paths = (
        ('-i', '--input', 'path to the input corpus file'),
        ('-o', '--output', 'path to the output doc codes file'),
        ('-lm', '--load_model', 'path to the trained model file'),
    )
    for short_flag, long_flag, description in required_paths:
        cli.add_argument(short_flag, long_flag, type=str, required=True, help=description)
    # cli.add_argument('-st', '--save_topics', type=str, help='path to the output topics file')
    test(cli.parse_args())
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()