-
Notifications
You must be signed in to change notification settings - Fork 0
/
arls_runs_reddit.m
95 lines (73 loc) · 2.52 KB
/
arls_runs_reddit.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
%% Runs of CP-ARLS with leverage score sampling for the Reddit tensor

%% User setup for script
do_diary = true; % False to disable diary recording
do_gitchecks = true; % False to disable git repo checks

%% Setup recording when running the entire file
% diarysetup is a project helper; the do_diary it returns replaces the
% user-set flag above.
[aname, diaryname, do_diary] = diarysetup(mfilename, do_diary);

%% Check relevant git repos
if do_gitchecks
    % Check Tensor Toolbox
    tmp = which('tensor');
    if isempty(tmp)
        % Nothing to check if the class is not on the path; keep going so
        % the check of the current directory below still runs.
        warning('arls_runs_reddit:tensorNotFound', ...
            'Could not locate the tensor class; skipping Tensor Toolbox git check.');
    else
        % tmp is <root>/@tensor/tensor.m. Strip the file name and the class
        % folder with fileparts instead of counting characters, and keep
        % the trailing separator that the previous tmp(1:end-16) left on.
        ttbdir = [fileparts(fileparts(tmp)), filesep];
        gitstatus(ttbdir);
    end

    % Check this directory
    gitstatus(pwd);
end
%% Experiment parameters
nruns = 4;          % Number of random initializations
R = 25;             % Rank passed to cp_arls_lev
srng = 2.^(17);     % Values tried for the 'nsamplsq' option
truefit = 'final';  % Passed through as the 'truefit' option
fsamp = 2^26;       % Sample count passed (twice) to tt_sample_stratified

%% Load in Reddit Tensor
% Announce the load before it starts, and pull only the variable we need
% into a struct so that nothing else stored in the MAT-file can overwrite
% the parameters defined above.
fprintf('---Loading in Reddit Tensor (with log counts)---\n')
S = load('/home/bwlarse/Tensors/tensor_data_reddit/reddit_log', 'reddit');
X = S.reddit;
clear S

sz = size(X);
d = ndims(X);
ns = length(srng); % Number of s-values
%% Allocate the results container
% One cell per (run, s-value, variant); the third dimension has two slots,
% one for each of the two cp_arls_lev calls made per run and s-value.
results = cell(nruns, ns, 2);

%% Build the fixed sample used by the function estimator
fprintf('---Setting up function estimators based on X---\n');

% rnginit is a project helper; it is given a fixed string here, presumably
% a seed so that the sample below is the same on every execution (confirm).
rnginit('41ea123768800000');

% Sorted linear indices of the stored entries of X.
xnzidx = sort(tt_sub2ind64(sz, X.subs));

% Draw the sample once, up front ...
[xsubs, xvals, wghts] = tt_sample_stratified(X, xnzidx, fsamp, fsamp);

% ... and hand back that same triple every time the sampler is invoked.
fsampler = @() deal(xsubs, xvals, wghts);
%% Run cp_arls_lev
% For every random initialization and every s-value, fit the model twice:
%   variant 1: 'thresh' = 1/s  (deterministic inclusion)
%   variant 2: 'thresh' = []   (no deterministic inclusion)
% The info struct of each fit is stored in results{rep, sidx, variant}.
rng('shuffle')

% Banner printed before each variant; used directly as the fprintf format.
banners = { ...
    '\nFinding CP decomposition (Deterministic Inclusion): \n', ...
    '\nFinding CP decomposition (No Deterministic): \n'};

for rep = 1:nruns
    fprintf('\n~~~~~Generating random initialization for run %d~~~~ \n', rep)
    rnginit;

    % One random factor matrix per mode, shared by every fit in this run.
    Uinit = cell(d,1);
    for k = 1:d
        Uinit{k} = rand(sz(k),R);
    end

    for sidx = 1:ns
        s = srng(sidx);
        fprintf('\n---Starting run %i/%i with %i samples---\n', rep, nruns, s)
        sharedparams = {'init', Uinit, 'truefit', truefit, 'nsamplsq', s, 'fsampler', fsampler};

        % Threshold for each variant, in the same order as the banners.
        threshes = {1.0/s, []};

        for v = 1:numel(banners)
            fprintf(banners{v})
            rnginit;
            tic
            % M stays in the workspace but only info is recorded.
            [M, ~, info] = cp_arls_lev(X,R,'thresh', threshes{v}, sharedparams{:});
            elapsed = toc;
            fprintf('Total Time (secs): %.3f\n', elapsed)

            % Replace the sampler handle and the initial guess with a
            % placeholder before storing (presumably to keep the saved
            % results small -- confirm).
            info.params.fsampler = 'removed';
            info.params.init = 'removed';
            results{rep,sidx,v} = info;
        end
    end

    % Checkpoint after every completed run.
    save('-v7.3', sprintf('%s-temp-results', aname), 'results')
end

% Save out the traces and fits for all runs
save('-v7.3', sprintf('%s-results', aname), 'results')

% Stop recording. NOTE(review): assumes diarysetup turned the diary on when
% do_diary is true; 'diary off' is harmless if it did not.
if do_diary
    diary off
end