-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_frank_wolfe_heterogeneous.py
executable file
·50 lines (41 loc) · 2.02 KB
/
test_frank_wolfe_heterogeneous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Module authorship metadata.
__author__, __email__ = "Jerome Thai", "jerome.thai@berkeley.edu"
import unittest
import numpy as np
from frank_wolfe_heterogeneous import fw_heterogeneous_1, fw_heterogeneous_2
from utils import braess_heterogeneous
class TestFrankWolfeHeterogeneous(unittest.TestCase):
    """Regression tests for ``fw_heterogeneous_1`` and ``fw_heterogeneous_2``.

    Each test builds the instance returned by ``braess_heterogeneous`` at
    three equal demand levels (0.25, 1.0, 0.75), runs the solver, and
    compares its output with hand-computed reference flows.
    """

    # Reference flows, each of shape (5, 2).  NOTE(review): presumably one
    # row per link of the Braess network and one column per class (matching
    # the two graphs / two demands handed to the solver) -- confirm in utils.
    FLOWS_025 = np.array([[.125, .25], [.125, .0], [.0, .25],
                          [.125, .0], [.125, .25]])
    FLOWS_100 = np.array([[.5, .5], [.5, .5], [.0, .0],
                          [.5, .5], [.5, .5]])
    FLOWS_075 = np.array([[.375, .625], [.375, .125], [.0, .5],
                          [.375, .125], [.375, .625]])
    # Relative-error tolerance used by every comparison below.
    TOL = 1e-2

    def check(self, f, true, eps):
        """Assert that ``f`` matches ``true`` to within error ``eps``.

        The error is the relative 2-norm error ||f - true|| / ||true||.  When
        ``true`` is all zeros, the absolute error ||f - true|| is used instead,
        so the comparison stays finite instead of producing a 0/0 (nan) or
        x/0 (inf) that could never pass.

        Args:
            f: computed values (array-like).
            true: reference values (array-like, same shape as ``f``).
            eps: strict upper bound on the error.
        """
        true = np.asarray(true, dtype=float)
        error = np.linalg.norm(np.asarray(f, dtype=float) - true)
        scale = np.linalg.norm(true)
        if scale > 0:
            error /= scale
        # single-argument print: identical output on Python 2 and Python 3
        print('error %s' % error)
        # assertLess reports both values on failure, unlike assertTrue(a < b)
        self.assertLess(error, eps)

    def _check_totals(self, f, true, eps):
        """Like ``check`` but compares only the sums over axis 1 (row sums)."""
        self.check(np.sum(f, 1), np.sum(true, 1), eps)

    @staticmethod
    def _solve_braess(solver, demand, **kwargs):
        """Run ``solver`` on the Braess instance where both classes have ``demand``.

        Builds ``braess_heterogeneous(demand, demand)`` and returns
        ``solver([g1, g2], [d1, d2], **kwargs)`` unchanged.
        """
        g1, g2, d1, d2 = braess_heterogeneous(demand, demand)
        return solver([g1, g2], [d1, d2], **kwargs)

    def test_fw_heterogeneous_1(self):
        """fw_heterogeneous_1 recovers the reference Braess flows."""
        print('test fw_heterogeneous 1')
        fs = self._solve_braess(fw_heterogeneous_1, .25, max_iter=200)
        self.check(fs, self.FLOWS_025, self.TOL)
        fs = self._solve_braess(fw_heterogeneous_1, 1., max_iter=200)
        self.check(fs, self.FLOWS_100, self.TOL)
        # At demand 0.75 only the row sums are compared.  NOTE(review):
        # presumably the per-class split is not unique there -- confirm.
        fs = self._solve_braess(fw_heterogeneous_1, .75, max_iter=200)
        self._check_totals(fs, self.FLOWS_075, self.TOL)

    def test_fw_heterogeneous_2(self):
        """fw_heterogeneous_2 recovers the reference Braess flows."""
        print('test fw_heterogeneous 2')
        fs = self._solve_braess(fw_heterogeneous_2, .25, q=10, past=10)
        self.check(fs, self.FLOWS_025, self.TOL)
        fs = self._solve_braess(fw_heterogeneous_2, 1., q=10, past=10)
        self.check(fs, self.FLOWS_100, self.TOL)
        # Same totals-only comparison as in test_fw_heterogeneous_1; this
        # case also uses a larger q and an explicit iteration cap.
        fs = self._solve_braess(fw_heterogeneous_2, .75,
                                q=200, past=10, max_iter=200)
        self._check_totals(fs, self.FLOWS_075, self.TOL)
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()