-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_Implicit_Regularization.m
127 lines (113 loc) · 4.31 KB
/
plot_Implicit_Regularization.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
function plot_Implicit_Regularization(lr, alpha, NUM_itr, init, Gammas, Loss_Type, delta)
% plot_Implicit_Regularization  Plot 3-D weight trajectories for several gammas.
%
% For every value in Gammas a 3x3 data matrix is built whose last row is
% [0, gamma, gamma], a 3-element weight vector is started at
% [init, init, init], and NUM_itr update steps are applied to a single
% rectified-linear unit. The visited weights (w1, w2, w3) are drawn as one
% curve per gamma in a new figure, together with a fixed grey reference
% segment at w1 = 5, w2 = -1.
%
% Inputs
%   lr        - step size multiplying every update.
%   alpha     - scale applied to the target vector [16; 18; 0].
%   NUM_itr   - number of update steps (one plotted point per step).
%   init      - initial value replicated into all three weights.
%   Gammas    - vector of gamma values, one curve each.
%   Loss_Type - "L2", "L1", "Log-Cosh" or "Huber_Loss"; any other value
%               prints a message and only the reference segment is drawn.
%   delta     - Huber threshold. Only read when Loss_Type is "Huber_Loss",
%               so it may be omitted for the other losses.
%
% Returns nothing; the only effects are the new figure and, for an
% unrecognised Loss_Type, a message on the command window.

figure()
y = alpha * [16; 18; 0];

% The losses differ only in how the residual (target - output) is turned
% into an error signal; the rest of the update is shared.
known_loss = true;
switch Loss_Type
    case "L2"
        err_fn = @(r) r;
        plot_title = "L2 loss";
    case "L1"
        err_fn = @(r) sign(r);
        plot_title = "L1 loss";
    case "Log-Cosh"
        err_fn = @(r) tanh(r);
        plot_title = "Log-Cosh loss";
    case "Huber_Loss"
        % delta is captured here (and nowhere else) so that callers using
        % the other losses are not forced to pass it.
        err_fn = @(r) huber_error_signal(r, delta);
        plot_title = "Huber loss";
    otherwise
        known_loss = false;
        disp("Loss function unkown")
end

if known_loss
    for j = 1:length(Gammas)
        gamma = Gammas(j);
        x = [[3, -1, 0]; [4, 2, 0]; [0, gamma, gamma]];
        w = [init, init, init];
        w1 = zeros(1, NUM_itr);
        w2 = zeros(1, NUM_itr);
        w3 = zeros(1, NUM_itr);
        % Updating the w
        for i = 1:NUM_itr
            out1 = x * w';
            % (v + |v|) / 2 is max(v, 0): the rectified output per sample.
            out2 = (out1 + abs(out1)) / 2;
            % sign(out2') is 1 where the unit is active and 0 elsewhere,
            % so inactive samples contribute nothing to the step.
            w = w + lr * sign(out2') .* err_fn(y' - out2') * x;
            w1(i) = w(1);
            w2(i) = w(2);
            w3(i) = w(3);
        end
        % plot3 must run before "hold on": holding an empty axes first
        % would leave it in the default 2-D view.
        plot3(w1, w2, w3, 'DisplayName', "gamma = " + num2str(gamma), 'LineWidth', 6);
        grid on
        hold on
    end
    title(plot_title)
end

% Grey reference segment along the w3 axis at w1 = 5, w2 = -1.
xx = [5,5];
yy = [-1,-1];
zz = [-0.25, 0.05];
plot3(xx, yy,zz, 'Color', [192/255,192/255,192/255], 'DisplayName','w1 = 5 & w2 = -1','LineWidth',3);
grid on
legend
end

function g = huber_error_signal(r, delta)
% huber_error_signal  Derivative of the Huber loss with respect to the residual.
% Returns r where |r| <= delta and delta * sign(r) where |r| > delta, so the
% signal is continuous at the threshold for every delta (not only delta = 1).
g = r;
outside = abs(r) > delta;
g(outside) = delta * sign(r(outside));
end