<TeXmacs|2.1.1>
<style|<tuple|generic|padded-paragraphs>>
<\body>
<\hide-preamble>
\;
<assign|author-name|<macro|author|Yang Zhi-Han>>
</hide-preamble>
<doc-data|<doc-title|Training Latent Variable Models with Auto-encoding
Variational Bayes: A Tutorial>|<doc-author|<author-data|<\author-affiliation>
<with|font-series|bold|Yang Zhi-Han>
\;
Department of Mathematics and Statistics
Carleton College
Northfield, MN 55057
<verbatim|yangz2@carleton.edu>
</author-affiliation>>>>
<with|font-base-size|12|<abstract-data|<abstract|<with|font-base-size|12|<with|font-series|bold|Auto-encoding
Variational Bayes> (AEVB) <cite|kingma2013auto> is a powerful and general
algorithm for fitting latent variable models (a promising direction for
unsupervised learning), and is well-known for training the Variational
Auto-Encoder (VAE). In this tutorial, we focus on motivating AEVB from the
classic <with|font-series|bold|Expectation Maximization> (EM) algorithm, as
opposed to from deterministic auto-encoders. Though natural and somewhat
self-evident, the connection between EM and AEVB is not emphasized in the
recent deep learning literature, and we believe that emphasizing this
connection can improve the community's understanding of AEVB. In
particular, we find it especially helpful to view (1) optimizing the
evidence lower bound<\footnote>
It is also called the variational lower bound, or the variational bound.\
</footnote> (ELBO) with respect to inference parameters as an
<with|font-series|bold|approximate E-step> and (2) optimizing ELBO with
respect to generative parameters as an <with|font-series|bold|approximate
M-step>; doing both simultaneously, as in AEVB, then amounts to tightening
and pushing up ELBO at the same time. We discuss how the approximate E-step
can be interpreted as performing <with|font-series|bold|variational inference>.
Important concepts such as amortization and the reparametrization trick are
discussed in detail. Finally, we derive from scratch the AEVB
training procedures of a non-deep and several deep latent variable models,
including VAE <cite|kingma2013auto>, Conditional VAE
<cite|sohn2015learning>, Gaussian Mixture VAE <cite|gmvae> and Variational
RNN <cite|chung2015recurrent>. We hope that readers will recognize
AEVB as a general algorithm that can be used to fit a wide range of latent
variable models (not just VAE), and will apply AEVB to such models that arise in
their own fields of research. PyTorch <cite|paszke2019pytorch> code for all
included models is publicly available><\footnote>
Code: <slink|https://github.com/zhihanyang2022/aevb-tutorial>
</footnote>.>>>
\;
\;
\;
<\table-of-contents|toc>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|1.<space|2spc>Latent
variable models> <datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-1><vspace|0.5fn>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|2.<space|2spc>Expectation
maximization> <datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-2><vspace|0.5fn>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|3.<space|2spc>Approximate
E-step as variational inference> <datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-3><vspace|0.5fn>
<with|par-left|1tab|3.1.<space|2spc>Variational inference
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-4>>
<with|par-left|1tab|3.2.<space|2spc>Amortized variational inference
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-5>>
<with|par-left|1tab|3.3.<space|2spc>Stochastic optimization of ELBO
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-6>>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|4.<space|2spc>Approximate
M-step> <datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-7><vspace|0.5fn>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|5.<space|2spc>Derivation
of AEVB for a few latent variable models>
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-8><vspace|0.5fn>
<with|par-left|1tab|5.1.<space|2spc>What exactly is the AEVB algorithm?
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-10>>
<with|par-left|1tab|5.2.<space|2spc>Factor analysis model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-11>>
<with|par-left|2tab|5.2.1.<space|2spc>Generative model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-12>>
<with|par-left|2tab|5.2.2.<space|2spc>Approximate posterior
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-14>>
<with|par-left|2tab|5.2.3.<space|2spc>Estimator of per-example ELBO
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-15>>
<with|par-left|2tab|5.2.4.<space|2spc>Results
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-16>>
<with|par-left|1tab|5.3.<space|2spc>Variational autoencoder
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-19>>
<with|par-left|2tab|5.3.1.<space|2spc>Generative model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-20>>
<with|par-left|2tab|5.3.2.<space|2spc>Approximate posterior
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-22>>
<with|par-left|2tab|5.3.3.<space|2spc>Results
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-23>>
<with|par-left|1tab|5.4.<space|2spc>Conditional VAE
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-28>>
<with|par-left|2tab|5.4.1.<space|2spc>Generative model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-29>>
<with|par-left|2tab|5.4.2.<space|2spc>Approximate posterior
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-30>>
<with|par-left|2tab|5.4.3.<space|2spc>Estimator for per-example ELBO
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-31>>
<with|par-left|2tab|5.4.4.<space|2spc>Results
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-32>>
<with|par-left|1tab|5.5.<space|2spc>Gaussian Mixture VAE
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-34>>
<with|par-left|2tab|5.5.1.<space|2spc>Generative model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-35>>
<with|par-left|2tab|5.5.2.<space|2spc>Approximate posterior
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-36>>
<with|par-left|2tab|5.5.3.<space|2spc>Estimator for per-example ELBO
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-37>>
<with|par-left|4tab|Estimator 1: Marginalization of
<with|mode|math|y<rsub|i>> <datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-38><vspace|0.15fn>>
<with|par-left|4tab|Estimator 2: Gumbel-Softmax trick
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-39><vspace|0.15fn>>
<with|par-left|2tab|5.5.4.<space|2spc>Results
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-40>>
<with|par-left|1tab|5.6.<space|2spc>Variational RNN
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-43>>
<with|par-left|2tab|5.6.1.<space|2spc>Generative model
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-44>>
<with|par-left|2tab|5.6.2.<space|2spc>Approximate posterior
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-45>>
<with|par-left|2tab|5.6.3.<space|2spc>Estimator for per-example ELBO
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-46>>
<with|par-left|2tab|5.6.4.<space|2spc>Results
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<no-break><pageref|auto-47>>
<vspace*|1fn><with|font-series|bold|math-font-series|bold|font-shape|small-caps|Bibliography>
<datoms|<macro|x|<repeat|<arg|x>|<with|font-series|medium|<with|font-size|1|<space|0.2fn>.<space|0.2fn>>>>>|<htab|5mm>>
<pageref|auto-49><vspace|0.5fn>
</table-of-contents>
\;
<section|Latent variable models><label|sec:latent-var-models>
In probabilistic machine learning, a <with|font-shape|italic|model> means a
(parametrized) probability distribution defined over variables of interest.
This includes classifiers and regressors, which can be viewed simply as
conditional distributions. A latent variable model is just a model that
contains some variables whose values are not observed. Therefore, for such
a model, we can divide the variables of interest into two vectors:
<math|<with|font-series|bold|x>>, which denotes the vector of observed
variables, and <math|<with|font-series|bold|z>>, which denotes the vector
of <with|font-shape|italic|latent> or unobserved variables.\
A strong motivation for using latent variable models is that some variables
in the generative process are naturally hidden from us, so we cannot observe
their values. In particular, latent variables can have the interpretation
of low-dimensional \Phidden causes\Q of high-dimensional observed
variables, and models that utilize latent variables \Poften have fewer
parameters than models that directly represent correlation in the
[observed] space\Q <cite|murphy2012machine>. The low dimensionality of
latent variables also means that they can serve as a compressed
representation of data. Additionally, latent variable models can be highly
expressive because they sum or integrate over hidden variables, which
makes them useful for purposes like black-box density estimation.
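As a small illustration of the last point, the following plain-Python sketch (not the tutorial's PyTorch code) uses a hypothetical model with a binary latent variable <math|z> and Gaussian conditionals <math|p<around*|(|x\<mid\>z|)>>; the mixing weights, means, and standard deviations are arbitrary values chosen only for illustration. Summing out <math|z> produces a bimodal marginal density, which no single Gaussian can represent.

```python
import math

# Hypothetical toy latent variable model (parameters chosen only for illustration):
#   z ~ Categorical(weights),  x | z ~ Normal(means[z], stds[z])
weights = [0.5, 0.5]
means = [-2.0, 2.0]
stds = [0.7, 0.7]


def normal_pdf(x, mean, std):
    """Density of a univariate Gaussian N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def marginal_pdf(x):
    """p(x) = sum_z p(z) p(x | z): the latent variable is summed out."""
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(weights, means, stds))


# Each conditional p(x | z) is unimodal, but the marginal p(x) has two modes.
for x in [-2.0, 0.0, 2.0]:
    print(f"p({x:+.1f}) = {marginal_pdf(x):.4f}")

# Sanity check: the marginal still integrates to (approximately) one.
step = 0.01
total = sum(marginal_pdf(-10.0 + k * step) * step for k in range(2001))
print(f"integral ~= {total:.4f}")
```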
<section|Expectation maximization>
In the classic statistical inference framework, fitting a model means
finding the maximum likelihood estimator (MLE) of the model parameters,
which is obtained by maximizing the log likelihood function (also known as
the evidence<\footnote>
The name \Pevidence\Q is also commonly used for
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font|cal|D>|)>>; in this
tutorial, \Pevidence\Q strictly means <math|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font|cal|D>|)>>.
</footnote>):
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|\<theta\>><rsup|\<ast\>>>|<cell|=>|<cell|arg
max<rsub|<with|font-series|bold|\<theta\>>> log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font|cal|D>|)>,>>>>
</eqnarray*>
where we have assumed that the model <math|p<rsub|<with|font-series|bold|\<theta\>>>>
is unconditional<\footnote>
All derivations can easily be adapted to conditional models.
</footnote> and we have <math|N> i.i.d. observations of the observed
variable <math|<with|font-series|bold|x>> stored in the dataset
<math|<with|font|cal|D>=<around*|{|<with|font-series|bold|x><rsub|1>,<with|font-series|bold|x><rsub|2>,\<cdots\>,<with|font-series|bold|x><rsub|N>|}>>.
In the same spirit, we can fit a latent variable model using MLE:
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|\<theta\>><rsup|\<ast\>>>|<cell|=>|<cell|arg
max<rsub|<with|font-series|bold|\<theta\>>> log
\ p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font|cal|D>|)>>>|<row|<cell|>|<cell|=>|<cell|arg
max<rsub|<with|font-series|bold|\<theta\>>><big|sum><rsub|i=1><rsup|N>log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>>|<row|<cell|>|<cell|=>|<cell|arg
max<rsub|<with|font-series|bold|\<theta\>>><big|sum><rsub|i=1><rsup|N>log<big|int>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>
d<with|font-series|bold|z><rsub|i>,<eq-number><label|eq:mle>>>>>
</eqnarray*>
where we have assumed that <math|<with|font-series|bold|z><rsub|i>> is a
continuous latent variable. If <math|<with|font-series|bold|z><rsub|i>> is
discrete, then the integral would be replaced by a sum. It is also valid
for one part of <math|<with|font-series|bold|z><rsub|i>> to be continuous
and the other part to be discrete. In general, evaluating this integral is
intractable, since it's essentially the normalization constant in Bayes'
rule:
<\equation*>
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>=<frac|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|<big|int>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>
d<with|font-series|bold|z><rsub|i>>.
</equation*>
Note that this integral (or sum) is tractable in some simple cases, though
evaluating it and plugging it into Equation <reference|eq:mle> still has
certain downsides (see Section 11.4.1 in <cite|murphy2012machine> for a
short discussion of this for Gaussian mixture models, and see
<cite|cs285slides> for a brief comment on numerical stability; it is
surprisingly difficult to find sources that discuss these downsides more
systematically). In the cases where it is intractable, the reason<\footnote>
Someone's personal communication with David Blei:
<slink|https://tinyurl.com/43auucww>
</footnote> for intractability can be that the integral has no closed-form
solution or that it is too expensive to compute. An interested reader is
encouraged to seek additional sources.\
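When <math|<with|font-series|bold|z><rsub|i>> is discrete with only a few values, as in a Gaussian mixture, both the marginalizing sum in Equation <reference|eq:mle> and the normalization constant in Bayes' rule can be computed exactly. The plain-Python sketch below assumes a hypothetical two-component 1D mixture with made-up parameters and works in log space with the log-sum-exp trick, one standard way of avoiding the underflow that a naive log-of-sum computation suffers from; we do not claim it is the specific remedy discussed in <cite|cs285slides>.

```python
import math

# Hypothetical two-component 1D Gaussian mixture (illustrative parameters).
log_weights = [math.log(0.3), math.log(0.7)]
means = [-1.0, 3.0]
stds = [1.0, 0.5]


def log_normal_pdf(x, mean, std):
    """log N(x; mean, std^2)."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    """Stable log(sum(exp(v))): shift by the max before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_joint(x):
    """[log p(x, z=k) for each k] = [log p(z=k) + log p(x | z=k)]."""
    return [lw + log_normal_pdf(x, m, s) for lw, m, s in zip(log_weights, means, stds)]


def log_marginal(x):
    """log p(x) = log sum_k p(x, z=k): the normalization constant of Bayes' rule."""
    return logsumexp(log_joint(x))


def posterior(x):
    """p(z=k | x) = p(x, z=k) / p(x), computed in log space."""
    lj = log_joint(x)
    lm = logsumexp(lj)
    return [math.exp(v - lm) for v in lj]


def log_likelihood(data):
    """Evidence of an i.i.d. dataset: sum_i log p(x_i)."""
    return sum(log_marginal(x) for x in data)


print(posterior(0.5))
print(log_likelihood([-1.2, 0.4, 2.9, 3.1]))
# A far-away point: every exp(log p(x, z=k)) underflows to 0.0 here, so a naive
# log(sum(exp(...))) would fail, while the log-sum-exp version stays finite.
print(log_marginal(60.0))
```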
While directly optimizing the evidence is difficult, it is possible to
derive a lower bound to the evidence, called the
<with|font-shape|italic|evidence lower bound> (ELBO), as follows:
<\eqnarray*>
<tformat|<table|<row|<cell|<big|sum><rsub|i=1><rsup|N>log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>|<cell|=>|<cell|<big|sum><rsub|i=1><rsup|N>log
\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|<frac|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>>|]><space|1em><around*|(|<text|introduce
distributions >q<rsub|i><text|'s>|)>>>|<row|<cell|>|<cell|\<geq\>>|<cell|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
<around*|(|<frac|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>>|)>|]><space|1em><around*|(|<text|apply
Jensen's inequality>|)>>>|<row|<cell|>|<cell|=>|<cell|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|]>+<big|sum><rsub|i=1><rsup|N>\<bbb-H\><around*|(|q<rsub|i>|)>>>|<row|<cell|>|<cell|\<triangleq\>>|<cell|ELBO<around*|(|<with|font-series|bold|\<theta\>>,<around*|{|q<rsub|i>|}>|)>,>>>>
</eqnarray*>
where the notation <math|ELBO<around*|(|<with|font-series|bold|\<theta\>>,<around*|{|q<rsub|i>|}>|)>>
emphasizes that ELBO is a function of <math|<with|font-series|bold|\<theta\>>>
and <math|<around*|{|q<rsub|i>|}>>. Importantly, Jensen's inequality
becomes an equality when the random variable is a constant. This happens
when <math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>=p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>,
so that <math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>/q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>>
becomes <math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>,
which does not depend on <math|<with|font-series|bold|z><rsub|i>>. Therefore,
suppose we keep alternating between (1) setting each
<math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>> to
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>,
so that the lower bound is tight at the current value of
<math|<with|font-series|bold|\<theta\>>>, and (2) maximizing the lower bound
with respect to <math|<with|font-series|bold|\<theta\>>>. Since step (1)
makes ELBO equal to the evidence and step (2) can only increase ELBO, which
never exceeds the evidence, the evidence never decreases, and the procedure
typically converges to a local maximum of the evidence. This is known as the
Expectation Maximization (EM) algorithm; Step 1 is called the E-step and
Step 2 is called the M-step. The algorithm is summarized below:
<\named-algorithm|\U Expectation Maximization (EM)>
<with|font-series|bold|Require.> <math|<with|font|cal|D>=<around*|{|<with|font-series|bold|x><rsub|1>,\<ldots\>,<with|font-series|bold|x><rsub|N>|}>>:
observed data; <math|<with|font-series|bold|\<theta\>><rsub|0>>: initial
value of parameters
<math|<with|font-series|bold|\<theta\>>\<leftarrow\><with|font-series|bold|\<theta\>><rsub|0>>
<with|font-series|bold|while> <with|font-series|bold|not> converged:
<space|2em><math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<leftarrow\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
for <math|i=1,\<ldots\>,N><space|12em>(Expectation step; E-step)
<space|2em><math|<with|font-series|bold|\<theta\>>\<leftarrow\> arg
max<rsub|<with|font-series|bold|\<theta\>>>
ELBO<around*|(|<with|font-series|bold|\<theta\>>,<around*|{|q<rsub|i>|}>|)>><space|14em>(Maximization
step; M-step)
<with|font-series|bold|end while>
</named-algorithm>
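To make the two steps concrete, here is a minimal plain-Python sketch of the algorithm above for a hypothetical 1D mixture of Gaussians, a simple case in which both steps have closed forms. The E-step sets each <math|q<rsub|i>> to the exact posterior over the component index (often called the responsibilities), and the M-step maximizes ELBO over the mixing weights, means, and variances with the standard weighted-average updates. The synthetic data, the initialization, and the small variance floor are our own assumptions.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """log N(x; mean, var)."""
    return -0.5 * (x - mean) ** 2 / var - 0.5 * math.log(2.0 * math.pi * var)


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def em_gmm(data, weights, means, variances, num_iters=50, var_floor=1e-6):
    """EM for a 1D Gaussian mixture; returns fitted parameters and the evidence trace."""
    K = len(weights)
    evidence_trace = []
    for _ in range(num_iters):
        # E-step: q_i(z_i = k) <- p_theta(z_i = k | x_i), computed in log space.
        resp = []
        evidence = 0.0
        for x in data:
            lj = [math.log(weights[k]) + log_normal_pdf(x, means[k], variances[k]) for k in range(K)]
            lm = logsumexp(lj)
            evidence += lm  # log p_theta(x_i) at the current theta (before the M-step)
            resp.append([math.exp(v - lm) for v in lj])
        evidence_trace.append(evidence)
        # M-step: theta <- argmax_theta ELBO(theta, {q_i}); closed form for a GMM.
        for k in range(K):
            n_k = sum(r[k] for r in resp)
            weights[k] = n_k / len(data)
            means[k] = sum(r[k] * x for r, x in zip(resp, data)) / n_k
            variances[k] = max(
                sum(r[k] * (x - means[k]) ** 2 for r, x in zip(resp, data)) / n_k, var_floor
            )
    return weights, means, variances, evidence_trace


random.seed(0)
data = [random.gauss(-2.0, 0.5) for _ in range(200)] + [random.gauss(3.0, 1.0) for _ in range(300)]
w, mu, var, trace = em_gmm(data, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
print("weights:", [round(v, 3) for v in w])
print("means:", [round(v, 3) for v in mu])
print("first/last evidence:", round(trace[0], 2), round(trace[-1], 2))
```

The recorded evidence should never decrease from one iteration to the next, which makes a convenient sanity check for an EM implementation.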
At this point, although we have argued that EM converges, it is not yet
clear whether the EM approach makes things any easier than using Equation
<reference|eq:mle> directly, nor what its consequences are. To partly answer
this question, in Sections 3 and 4 we will discuss a very general and
modular template for extending EM to models for which the E-step (again,
due to the intractability of the marginalizing integral) and the M-step are
intractable; this template leads to AEVB in Section <reference|sec:models>.
In Section <reference|sec:models>, we will apply this general template to
derive AEVB training procedures for several interesting latent variable
models. Equation <reference|eq:mle> cannot be extended in a similar fashion.
Before we move on, it is important to consider an alternative derivation
of ELBO (see Section 22.2.2 of <cite|pml2Book>) that gives more insight
into the size of the gap between the lower bound and the true objective,
which is called the <with|font-shape|italic|variational gap>. This
derivation turns out to be of <with|font-shape|italic|great> importance for
later sections. In particular, if we write the marginal
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>|)>>
as <math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>,<with|font-series|bold|z>|)>/p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z>\<mid\><with|font-series|bold|x>|)>>
instead of <math|<big|int>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>,<with|font-series|bold|z>|)>
d<with|font-series|bold|z>>, then we no longer need Jensen's inequality to
move the expectation outside the log: since <math|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>
does not depend on <math|<with|font-series|bold|z><rsub|i>>, we can place
the expectation outside from the start. Starting from the evidence, one can
show that
<\eqnarray*>
<tformat|<table|<row|<cell|<big|sum><rsub|i=1><rsup|N>log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>|<cell|=>|<cell|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>|]>>>|<row|<cell|>|<cell|=>|<cell|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
<around*|(|<frac|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>|)>|]>>>|<row|<cell|>|<cell|=>|<cell|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
<around*|(|<frac|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>>\<cdot\><frac|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>|)>|]>>>|<row|<cell|>|<cell|=>|<cell|<wide*|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>|]>|\<wide-underbrace\>><rsub|<text|ELBO><around*|(|<with|font-series|bold|\<theta\>>,<around*|{|q<rsub|i>|}>|)>>+<big|sum><rsub|i=1><rsup|N><wide*|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log <around*|(|<frac|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>|)>|]>|\<wide-underbrace\>><rsub|D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|)>>.<eq-number><label|eq:ll-elbo-kl>>>>>
</eqnarray*>
We see that the gap between the evidence and ELBO is exactly the sum of
KL divergences between the chosen distributions <math|q<rsub|i>> and the
true posteriors. Since <math|D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|)>=0>
if and only if <math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>=p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>,
this agrees with our previous derivation using Jensen's inequality.\
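The decomposition in Equation <reference|eq:ll-elbo-kl> can be checked numerically for a single observation whose latent variable takes only three values, since every expectation is then an exact finite sum. The joint probabilities and the choice of <math|q> in this plain-Python sketch are hypothetical numbers used only for illustration.

```python
import math

# Hypothetical joint probabilities p_theta(x, z=k) for one fixed observation x.
joint = [0.02, 0.10, 0.08]
evidence = math.log(sum(joint))                 # log p_theta(x)
true_post = [p / sum(joint) for p in joint]     # p_theta(z | x) via Bayes' rule


def elbo(q):
    """E_q[log p_theta(x, z) - log q(z)], as an exact sum over z."""
    return sum(qk * (math.log(pk) - math.log(qk)) for qk, pk in zip(q, joint) if qk > 0)


def kl(q, p):
    """D_KL(q || p) = E_q[log q(z) - log p(z)]."""
    return sum(qk * (math.log(qk) - math.log(pk)) for qk, pk in zip(q, p) if qk > 0)


q = [0.5, 0.3, 0.2]  # an arbitrary q that differs from the posterior
print(evidence, elbo(q) + kl(q, true_post))     # log p(x) = ELBO + KL
print(elbo(true_post) - evidence)               # the bound is tight at q = posterior
```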
<section|Approximate E-step as variational inference>
Sections <reference|sec:avi> and <reference|sec:stocopt> partly follow the
treatment of <cite|cs285slides> and <cite|pml2Book>, respectively.\
<\subsection>
Variational inference
</subsection>
To reduce the variational gap in Equation <reference|eq:ll-elbo-kl> (while
assuming that <math|<with|font-series|bold|\<theta\>>> is fixed) when
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
is intractable, we define a family of distributions
<math|<with|font|cal|Q>> and aim to find, for each <math|i>, a distribution
<math|q<rsub|i>\<in\><with|font|cal|Q>> such that the KL divergence between
<math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>> and
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
is minimized, i.e.,\
<\equation*>
q<rsup|\<ast\>><rsub|i>=arg min<rsub|q<rsub|i>\<in\><with|font|cal|Q>>
D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|)><space|1em><text|for
>i=1,\<ldots\>,N.
</equation*>
Since this is (1) optimizing over functions (probability distributions are
functions) and (2) doing inference (i.e., obtaining some representation of
the true posterior), it is called \Pvariational inference\Q; in the calculus
of variations, a \Pvariation\Q is a small change in a function. In practice,
each <math|q<rsub|i>> has parameters that we optimize over, so we are not
directly optimizing over functions.
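As a toy sketch of optimizing the parameters of <math|q<rsub|i>> (our own construction, not taken from the cited sources), let <math|<with|font-series|bold|z><rsub|i>> take three values and parametrize <math|q<rsub|i>> by softmax logits <math|\<phi\>>. Writing <math|f<rsub|k>=log p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>=k|)>-log q<rsub|k>>, the exact ELBO is <math|<big|sum><rsub|k>q<rsub|k>f<rsub|k>>, and its gradient with respect to <math|\<phi\><rsub|j>> works out to <math|q<rsub|j><around*|(|f<rsub|j>-ELBO|)>>, so plain gradient ascent can be used. Because this family contains every distribution with full support on the three values, the maximizer recovers the true posterior. The joint probabilities, step size, and iteration count are illustrative assumptions.

```python
import math

# Hypothetical joint probabilities p_theta(x, z=k) for one fixed x (theta held fixed).
log_joint = [math.log(0.01), math.log(0.05), math.log(0.04)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def elbo_and_grad(phi):
    """Exact ELBO(q_phi) = sum_k q_k (log p(x, k) - log q_k) and its gradient in phi."""
    q = softmax(phi)
    f = [lj - math.log(qk) for lj, qk in zip(log_joint, q)]
    elbo = sum(qk * fk for qk, fk in zip(q, f))
    grad = [qk * (fk - elbo) for qk, fk in zip(q, f)]  # d ELBO / d phi_j
    return elbo, grad


phi = [0.0, 0.0, 0.0]  # start from the uniform distribution
step_size = 1.0
for _ in range(500):
    _, grad = elbo_and_grad(phi)
    phi = [p + step_size * g for p, g in zip(phi, grad)]  # gradient *ascent* on ELBO

q = softmax(phi)
log_evidence = math.log(sum(math.exp(v) for v in log_joint))
true_post = [math.exp(v - log_evidence) for v in log_joint]
print("q*        :", [round(v, 4) for v in q])
print("posterior :", [round(v, 4) for v in true_post])
print("gap       :", log_evidence - elbo_and_grad(phi)[0])
```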
If the true posterior is contained in <math|<with|font|cal|Q>> (i.e.,
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>\<in\><with|font|cal|Q>>),
then clearly <math|q<rsup|\<ast\>><rsub|i>=p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
and setting each <math|q<rsub|i>=q<rsup|\<ast\>><rsub|i>> in Equation
<reference|eq:ll-elbo-kl> will make ELBO a tight bound (with respect to
<math|<with|font-series|bold|\<theta\>>>) because the sum of KL divergences
will be zero. Otherwise, ELBO will not be tight, but maximizing it can
still be useful because (1) it is by definition a lower bound on the
evidence, the quantity we care about, and (2) if <math|<with|font|cal|Q>>
is flexible, <math|q<rsup|\<ast\>><rsub|i>> can still be a good
approximation of the true posterior, so ELBO will not be too loose.
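To see what happens when <math|<with|font|cal|Q>> does not contain the true posterior, consider a hypothetical example of our own: <math|<with|font-series|bold|z>=<around*|(|z<rsub|1>,z<rsub|2>|)>> consists of two binary latent variables, and <math|<with|font|cal|Q>> is the factorized (mean-field) family <math|q<around*|(|z<rsub|1>|)>q<around*|(|z<rsub|2>|)>>. The plain-Python sketch below finds the best member of <math|<with|font|cal|Q>> by brute-force grid search over its two Bernoulli parameters (a stand-in for gradient-based optimization) and reports the remaining gap between the evidence and the maximal ELBO. For a strongly correlated posterior the gap stays clearly positive, whereas for a posterior that already factorizes it vanishes.

```python
import itertools
import math

STATES = list(itertools.product([0, 1], repeat=2))  # (z1, z2) in {0,1}^2


def factorized_q(a, b):
    """q(z1, z2) = Bernoulli(z1; a) * Bernoulli(z2; b)."""
    return {(z1, z2): (a if z1 else 1 - a) * (b if z2 else 1 - b) for z1, z2 in STATES}


def best_mean_field_gap(joint):
    """log p(x) - max_{q in Q} ELBO(q), with Q searched on a grid of (a, b)."""
    log_evidence = math.log(sum(joint.values()))
    grid = [i / 100 for i in range(1, 100)]
    best_elbo = -math.inf
    for a in grid:
        for b in grid:
            q = factorized_q(a, b)
            elbo = sum(q[z] * (math.log(joint[z]) - math.log(q[z])) for z in STATES)
            best_elbo = max(best_elbo, elbo)
    return log_evidence - best_elbo


# Correlated posterior (z1 and z2 tend to agree), scaled by 0.2 to act as p(x, z).
correlated = {(0, 0): 0.45 * 0.2, (0, 1): 0.05 * 0.2, (1, 0): 0.05 * 0.2, (1, 1): 0.45 * 0.2}
# A posterior that already factorizes, so it lies inside Q.
independent = {z: 0.2 * p for z, p in factorized_q(0.3, 0.6).items()}

print("gap, correlated posterior :", best_mean_field_gap(correlated))
print("gap, factorized posterior :", best_mean_field_gap(independent))
```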
The <with|font-shape|italic|challenge> of this optimization problem is that
the true posterior is not tractable and hence not available, so directly
minimizing the KL divergence is not an option. In this case, there are two
perspectives that lead to the same solution but are conceptually somewhat
different:
<\enumerate>
<item>From Equation <reference|eq:ll-elbo-kl>, we see that minimizing the
sum of KL divergences with respect to <math|q<rsub|i>>'s is equivalent to
maximizing ELBO with respect to <math|q<rsub|i>>'s, since the evidence on
the left-hand side is a constant with respect to <math|q<rsub|i>>'s.
Fortunately, ELBO is fairly easy to evaluate: we always know how to
evaluate the unnormalized posterior <math|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>>
for any graphical model, and the expectation operator outside can be
sidestepped with techniques that we will discuss in detail in Section
<reference|sec:models> on a per-model basis.
<item>One can also minimize KL divergences with respect to
<math|q<rsub|i>>'s by dealing with the
<with|font-shape|italic|unnormalized> true posteriors. This is a standard
approach: see Section 21.2 of <cite|murphy2012machine> for a textbook
treatment and Section 3 of <cite|blundell2015weight> for an application
to deep neural networks. Importantly, the unnormalized true posterior is
tractable for all graphical models, which is the foundation for
techniques such as Markov Chain Monte Carlo (MCMC). Slightly abusing the
notation of KL divergence (as <math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>>
is unnormalized), we have
<\eqnarray*>
<tformat|<table|<row|<cell|D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|)>>|<cell|=>|<cell|D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|)>+log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>|)>>>|<row|<cell|>|<cell|=>|<cell|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>><around*|[|log
q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>-log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>|]>.<space|1em><around*|(|<text|dropped
constant>|)>>>>>
</eqnarray*>
But this is just the negation of ELBO, and minimizing the negation of
ELBO with respect to <math|q<rsub|i>>'s is equivalent to maximizing ELBO
with respect to <math|q<rsub|i>>'s.
</enumerate>
While these two solutions are the same, they came from two different
derivations and can give different insights.
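As a small numerical check of the identity that both derivations rely on, the following NumPy sketch uses a hypothetical toy setting that is not one of the models in this tutorial: a single observation and a discrete latent variable with three states, so that the evidence, ELBO and KL divergence can all be computed exactly by summation. The function <python|elbo_and_kl> and the numbers in the usage are illustrative assumptions.

```python
import numpy as np


def elbo_and_kl(log_joint, q):
    """Toy discrete-latent check (hypothetical example).

    log_joint[k] = log p_theta(x, z=k) for one fixed observation x (unnormalized posterior),
    q[k]         = q(z=k), a probability vector with full support.
    Returns (log evidence, ELBO, KL(q || p_theta(z | x))).
    """
    log_evidence = np.logaddexp.reduce(log_joint)    # log sum_k p(x, z=k)
    log_posterior = log_joint - log_evidence          # normalized log p(z | x)
    elbo = np.sum(q * (log_joint - np.log(q)))        # E_q[log p(x, z) - log q(z)]
    kl = np.sum(q * (np.log(q) - log_posterior))      # KL(q || p(z | x))
    return log_evidence, elbo, kl


log_joint = np.log(np.array([0.02, 0.05, 0.03]))
q = np.array([0.5, 0.3, 0.2])
log_evidence, elbo, kl = elbo_and_kl(log_joint, q)
# log evidence = ELBO + KL, and KL >= 0, so ELBO is a lower bound.
```

Computing <python|elbo> only requires the unnormalized quantities in <python|log_joint>; the normalizer <python|log_evidence> is needed only for the KL term, which is why maximizing ELBO is a usable substitute for minimizing the KL divergence directly.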
<subsection|Amortized variational inference><label|sec:avi>
One natural and convenient way to define a family of distributions
<math|<with|font|cal|Q>> is through a parametrized family. However, doing
this naively means that the number of parameters of the <math|q<rsub|i>>'s
grows linearly with the number of data points
<math|<with|font-series|bold|x><rsub|i>>. For example, if each
<math|q<rsub|i>> is an isotropic Gaussian with parameters
<math|<with|font-series|bold|\<mu\>><rsub|i>> and
<math|<with|font-series|bold|\<sigma\>><rsub|i>>, then <math|q<rsub|i>>'s
altogether would have <math|<around*|(|<around*|\||<with|font-series|bold|\<mu\>><rsub|i>|\|>+<around*|\||<with|font-series|bold|\<sigma\>><rsub|i>|\|>|)>\<times\>N>
parameters. In such cases, it is convenient to represent
<math|q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>> by a neural
network <math|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
such that <math|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>=q<rsub|i><around*|(|<with|font-series|bold|z><rsub|i>|)>>,
where <math|<with|font-series|bold|\<phi\>>> is referred to as the
<with|font-shape|italic|inference <with|font-shape|right|parameters>> (as
opposed to <math|<with|font-series|bold|\<theta\>>>, the
<with|font-shape|italic|generative> parameters). This approach is called
<with|font-shape|italic|amortized> variational inference because the
\Pcost\Q of having a fixed but large number of parameters gradually \Ppays
off\Q, in terms of memory usage and generalization benefits, as the size of
the dataset grows. With amortization, the right-hand side of Equation
<reference|eq:ll-elbo-kl> becomes
<\equation*>
<wide*|<big|sum><rsub|i=1><rsup|N>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>|\<wide-underbrace\>><rsub|<text|ELBO><around*|(|<with|font-series|bold|\<theta\>>,<with|font-series|bold|\<phi\>>|)>>+<big|sum><rsub|i=1><rsup|N><wide*|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log <frac|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>|]>|\<wide-underbrace\>><rsub|D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>\<parallel\>p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|)>>.
</equation*>
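To make the parameter-count argument concrete, here is a minimal NumPy sketch that assumes a hypothetical linear encoder standing in for the neural network <math|q<rsub|<with|font-series|bold|\<phi\>>>>. Its parameter count depends only on the data dimension <math|D> and the latent dimension <math|L>, whereas separate isotropic Gaussians for every data point need <math|2L N> parameters. The class and function names and the MNIST-like sizes are illustrative, not part of this tutorial's code.

```python
import numpy as np


def per_datapoint_param_count(N, L):
    # One L-dim mean vector and one L-dim standard-deviation vector per data point.
    return (L + L) * N


class LinearAmortizedEncoder:
    """Hypothetical amortized encoder q_phi(z | x) = N(A x + b, diag(exp(C x + d))**2).

    A single set of parameters phi = (A, b, C, d) is shared by all data points.
    """

    def __init__(self, D, L, seed=0):
        rng = np.random.default_rng(seed)
        self.A = 0.1 * rng.standard_normal((L, D))
        self.b = np.zeros(L)
        self.C = 0.1 * rng.standard_normal((L, D))
        self.d = np.zeros(L)

    def num_params(self):
        return self.A.size + self.b.size + self.C.size + self.d.size

    def __call__(self, xs):
        # xs: (batch, D); returns means and standard deviations, each of shape (batch, L).
        mus = xs @ self.A.T + self.b
        sigmas = np.exp(xs @ self.C.T + self.d)  # exp keeps standard deviations positive
        return mus, sigmas


enc = LinearAmortizedEncoder(D=784, L=2)
```

For <math|N=60000> images with <math|L=2>, the per-data-point approach needs 240000 parameters, whereas this encoder for <math|D=784> has 3140 regardless of <math|N>.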
<subsection|Stochastic optimization of ELBO><label|sec:stocopt>
To shrink the gap between ELBO and the evidence (i.e., to perform an
approximate E-step), we can maximize ELBO with respect to
<math|<with|font-series|bold|\<phi\>>> with mini-batch gradient ascent
<with|font-shape|italic|until convergence>:
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|\<phi\>><rsup|t+1>>|<cell|\<leftarrow\>>|<cell|<with|font-series|bold|\<phi\>><rsup|t>+\<eta\>\<nabla\><rsub|<with|font-series|bold|\<phi\>>><around*|{|<wide*|<frac|1|N<rsub|B>><big|sum><rsub|i=1><rsup|N<rsub|B>><wide*|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>|\<wide-underbrace\>><rsub|<text|per-example
ELBO>>|\<wide-underbrace\>><rsub|<text|mini-batch
ELBO>>|}><rsub|<with|font-series|bold|\<phi\>>=<with|font-series|bold|\<phi\>><rsup|t>>>>|<row|<cell|>|<cell|=>|<cell|<with|font-series|bold|\<phi\>><rsup|t>+\<eta\>
<frac|1|N<rsub|B>><big|sum><rsub|i=1><rsup|N<rsub|B>>\<nabla\><rsub|<with|font-series|bold|\<phi\>>><around*|{|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>|}><rsub|<with|font-series|bold|\<phi\>>=<with|font-series|bold|\<phi\>><rsup|t>>,<eq-number><label|eq:e-step>>>>>
</eqnarray*>
where <math|\<eta\>\<gtr\>0> is the learning rate and <math|N<rsub|B>> is
the batch size. We have divided the mini-batch ELBO by <math|N<rsub|B>> so
that picking <math|\<eta\>> can be decoupled from picking
<math|N<rsub|B>>. However, unless the expectations within the gradient
operators are tractable, the gradients cannot be evaluated exactly. We will
discuss solutions to this problem on a per-model basis in Section
<reference|sec:models>.
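For intuition, the update in Equation <reference|eq:e-step> can be sketched in NumPy on a hypothetical toy model whose expectations happen to be tractable: <math|p(z)=<with|font|cal|N>(0,1)>, <math|p(x\<mid\>z)=<with|font|cal|N>(z,1)> with the generative parameters held fixed, and <math|q<rsub|<with|font-series|bold|\<phi\>>>(z\<mid\>x)=<with|font|cal|N>(a x,s<rsup|2>)> with <math|<with|font-series|bold|\<phi\>>=(a,log s)>. Averaging the per-example gradients over the mini-batch implements the <math|1/N<rsub|B>> factor. Because this family contains the true posterior <math|<with|font|cal|N>(x/2,1/2)>, the iterates should approach <math|a=0.5> and <math|s=<sqrt|0.5>>, at which point ELBO is tight.

```python
import numpy as np


def e_step(xs, lr=0.05, steps=2000, batch_size=32, seed=0):
    """Approximate E-step by mini-batch gradient ascent (hypothetical toy model).

    Model (theta held fixed): p(z) = N(0, 1), p(x | z) = N(z, 1).
    Family: q_phi(z | x) = N(a * x, s**2), with phi = (a, rho) and s = exp(rho).
    Here the per-example ELBO is tractable:
        ELBO_i = -0.5 * (x_i - a * x_i)**2 - 0.5 * (a * x_i)**2 - s**2 + log(s) + const
    """
    rng = np.random.default_rng(seed)
    a, rho = 0.0, 0.0
    for _ in range(steps):
        batch = rng.choice(xs, size=batch_size, replace=False)
        s = np.exp(rho)
        # Exact per-example gradients, averaged over the mini-batch (the 1/N_B factor).
        grad_a = np.mean(batch**2 * (1.0 - 2.0 * a))
        grad_rho = 1.0 - 2.0 * s**2  # d/drho of (-s**2 + log s); identical for every example
        a += lr * grad_a
        rho += lr * grad_rho
    return a, np.exp(rho)
```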
<section|Approximate M-step><label|m-step>
After an approximate E-step is completed, we can then maximize ELBO with
respect to <math|<with|font-series|bold|\<theta\>>> with mini-batch
gradient ascent <with|font-shape|italic|until convergence>:
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|<with|font-series|bold|\<theta\>>><rsup|t+1>>|<cell|\<leftarrow\>>|<cell|<with|font-series|bold|<with|font-series|bold|\<theta\>>><rsup|t>+\<eta\>\<nabla\><rsub|<with|font-series|bold|<with|font-series|bold|\<theta\>>>><around*|{|<frac|1|N<rsub|B>><big|sum><rsub|i=1><rsup|N<rsub|B>>\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>|}><rsub|<with|font-series|bold|<with|font-series|bold|\<theta\>>>=<with|font-series|bold|\<theta\>><rsup|t>>>>|<row|<cell|>|<cell|=>|<cell|<with|font-series|bold|<with|font-series|bold|\<theta\>>><rsup|t>+\<eta\>
<frac|1|N<rsub|B>><big|sum><rsub|i=1><rsup|N<rsub|B>>\<nabla\><rsub|<with|font-series|bold|<with|font-series|bold|\<theta\>>>><around*|{|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>|}><rsub|<with|font-series|bold|<with|font-series|bold|\<theta\>>>=<with|font-series|bold|\<theta\>><rsup|t>>.<eq-number><label|eq:m-step>>>>>
</eqnarray*>
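Unlike in the approximate E-step, here the gradient operator with respect to
<math|<with|font-series|bold|\<theta\>>> can be moved
<with|font-shape|italic|inside> the expectation, because the distribution
<math|q<rsub|<with|font-series|bold|\<phi\>>>> that the expectation is taken
over does not depend on <math|<with|font-series|bold|\<theta\>>>; the
expectation can then be estimated by sampling. We will discuss this in more
detail in Section <reference|sec:models>.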
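To see why the gradient can be moved inside the expectation, consider a hypothetical toy case with <math|p<rsub|\<theta\>>(x\<mid\>z)=<with|font|cal|N>(\<theta\> z,1)> and a fixed <math|q(z)=<with|font|cal|N>(m,s<rsup|2>)>. Since the samples do not depend on <math|\<theta\>>, averaging per-sample gradients gives an unbiased estimate, which this NumPy sketch compares with the closed-form gradient. The same swap would not be valid for <math|<with|font-series|bold|\<phi\>>> without further work, because there the sampling distribution itself depends on the parameters.

```python
import numpy as np


def m_step_grad_mc(x, theta, m, s, n_samples, rng):
    """Monte Carlo estimate of d/dtheta E_{z ~ q}[log p_theta(x | z)] for a toy model
    (assumed for illustration): p_theta(x | z) = N(theta * z, 1), q(z) = N(m, s**2).

    q does not depend on theta, so we sample from q first and then average the
    per-sample gradients d/dtheta log p_theta(x | z) = (x - theta * z) * z.
    """
    zs = m + s * rng.standard_normal(n_samples)
    return np.mean((x - theta * zs) * zs)


def m_step_grad_exact(x, theta, m, s):
    # E_q[(x - theta z) z] = x m - theta (m**2 + s**2), since E_q[z**2] = m**2 + s**2
    return x * m - theta * (m**2 + s**2)
```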
<section|Derivation of AEVB for a few latent variable
models><label|sec:models><float|float|t|<\big-table|<tabular|<tformat|<cwith|1|-1|1|-1|cell-halign|c>|<cwith|1|-1|1|-1|cell-valign|c>|<cwith|1|1|2|-1|font-shape|italic>|<cwith|1|-1|1|-1|cell-hyphen|n>|<cwith|1|-1|1|-1|cell-tborder|1ln>|<cwith|1|-1|1|-1|cell-bborder|1ln>|<cwith|1|-1|1|-1|cell-lborder|1ln>|<cwith|1|-1|1|-1|cell-rborder|1ln>|<table|<row|<cell|<with|font-shape|italic|Model>>|<cell|Observed>|<cell|Latent>|<cell|Joint
density / generative model>>|<row|<cell|FA>|<cell|<math|<with|font-series|bold|x>\<in\>\<bbb-R\><rsup|D>>>|<cell|<math|<with|font-series|bold|z>\<in\>\<bbb-R\><rsup|L>>>|<cell|<math|<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>\<mid\><with|font-series|bold|z>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>><wide*|p<around*|(|<with|font-series|bold|z>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>>>>>|<row|<cell|VAE>|<cell|<math|<with|font-series|bold|x>\<in\><around*|[|0,1|]><rsup|784>>>|<cell|<math|<with|font-series|bold|z>\<in\>\<bbb-R\><rsup|L>>>|<cell|<math|<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>\<mid\><with|font-series|bold|z>|)>|\<wide-underbrace\>><rsub|<text|ProductOfContinuousBernoullis>>
<wide*|p<around*|(|<with|font-series|bold|z>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>>>>>|<row|<cell|CVAE>|<cell|<math|<with|font-series|bold|x>\<in\><around*|[|0,1|]><rsup|784>>>|<cell|<math|<with|font-series|bold|z>\<in\>\<bbb-R\><rsup|L>>>|<cell|<math|<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>\<mid\><with|font-series|bold|z>,<with|font-series|bold|y>|)>|\<wide-underbrace\>><rsub|<text|ProductOfContinuousBernoullis>>
<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z>\<mid\><with|font-series|bold|y>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>>>>>|<row|<cell|GMVAE>|<cell|<math|<with|font-series|bold|x>\<in\><around*|{|0,1|}><rsup|784>>>|<cell|<math|<with|font-series|bold|y>\<in\><text|OneHot><around*|(|C|)>,
<with|font-series|bold|z>\<in\>\<bbb-R\><rsup|L>>>|<cell|<math|<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>\<mid\><with|font-series|bold|z>|)>|\<wide-underbrace\>><rsub|<text|ProductOfBernoullis>>
<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z>\<mid\><with|font-series|bold|y>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>>
<wide*|p<around*|(|<with|font-series|bold|y>|)>|\<wide-underbrace\>><rsub|<text|OneHotCategorical>>>>>|<row|<cell|VRNN>|<cell|<math|<with|font-series|bold|x><rsub|t>\<in\><around*|{|0,1|}><rsup|28>>>|<cell|<math|<with|font-series|bold|z><rsub|t>\<in\>\<bbb-R\><rsup|L>>>|<cell|<math|<big|prod><rsub|t=1><rsup|T><wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|t>\<mid\><with|font-series|bold|x><rsub|\<less\>t>,<with|font-series|bold|z><rsub|\<leq\>t>|)>|\<wide-underbrace\>><rsub|<text|ProductOfBernoullis>>
<wide*|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|t>\<mid\><with|font-series|bold|x><rsub|\<less\>t>,<with|font-series|bold|z><rsub|\<less\>t>|)>|\<wide-underbrace\>><rsub|<text|Gaussian>>>>>>>>>
Summary of latent variable models presented in this tutorial. FA means
Factor Analysis (Section <reference|sec:fa>); VAE means Variational
Auto-Encoder (<cite|kingma2013auto>; Section <reference|sec:vae>); CVAE
means Conditional VAE (<cite|sohn2015learning>; Section
<reference|sec:cvae>); GMVAE means Gaussian Mixture VAE (<cite|gmvae>;
Section <reference|sec:gmvae>); VRNN means Variational Recurrent Neural
Network (<cite|chung2015recurrent>; Section <reference|sec:vrnn>). For
all models except FA, the dataset used was MNIST (images of size
<math|28\<times\>28>). More specifically, we used normalized MNIST for
VAE and CVAE and binarized MNIST for GMVAE and VRNN. This was done to
showcase that latent variable models can have a variety of output
distributions. \ \ <label|table:all-models>
</big-table>>
In this section, we will derive the AEVB training procedure for models
listed in Table <reference|table:all-models>.
<subsection|What exactly is the AEVB algorithm?>
So far, we have discussed how the E-step and the M-step of EM can be
approximated. However, we have not yet arrived at AEVB<\footnote>
We have already covered the \PVB\Q part: it refers to the fact that the
approximate E-step is essentially (amortized) variational inference, which
is also commonly referred to as the Variational Bayesian approach.
</footnote>. Compared to simply alternating approximate E-steps and M-steps,
AEVB makes two additional changes. Firstly, instead of waiting for the
approximate E-step or M-step to converge before switching to the other,
AEVB performs gradient ascent with respect to
<math|<with|font-series|bold|\<phi\>>> and
<math|<with|font-series|bold|\<theta\>>> simultaneously. This can lead to
faster convergence, as we shall see in Section <reference|sec:fa-results>.
Secondly, the \PAE\Q part refers to using a specific unbiased, low-variance
and easy-to-evaluate estimator of the per-example ELBO, chosen such that the
gradient of that estimator (with respect to
<math|<with|font-series|bold|\<phi\>>> and
<math|<with|font-series|bold|\<theta\>>>) is an unbiased estimator of the
gradient of the per-example ELBO. In this section, we will derive this
estimator for several models and showcase PyTorch code snippets for
implementing the generative, inferential and algorithmic components.
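Both changes can be illustrated with a minimal NumPy sketch on a hypothetical one-dimensional linear-Gaussian model (an assumption for illustration, not one of the models below): <math|p(z)=<with|font|cal|N>(0,1)>, <math|p<rsub|\<theta\>>(x\<mid\>z)=<with|font|cal|N>(\<theta\> z,1)> and <math|q<rsub|<with|font-series|bold|\<phi\>>>(z\<mid\>x)=<with|font|cal|N>(a x,s<rsup|2>)>. Every step draws one reparametrized sample per example and uses the same single-sample ELBO estimate to update <math|<with|font-series|bold|\<phi\>>=(a,log s)> and <math|\<theta\>> together; the gradients are derived by hand rather than by automatic differentiation. At a stationary point, <math|\<theta\><rsup|2>> equals the data's second moment minus one, <math|a=\<theta\>/(\<theta\><rsup|2>+1)> and <math|s<rsup|2>=1/(\<theta\><rsup|2>+1)>, i.e., <math|q> matches the exact posterior.

```python
import numpy as np


def aevb_toy(xs, lr=0.01, steps=6000, batch_size=64, seed=0):
    """AEVB-style joint updates on a hypothetical 1-D linear-Gaussian model.

    Generative model: p(z) = N(0, 1), p_theta(x | z) = N(theta * z, 1).
    Inference model:  q_phi(z | x) = N(a * x, s**2), phi = (a, rho), s = exp(rho).
    """
    rng = np.random.default_rng(seed)
    a, rho, theta = 0.0, 0.0, 0.5
    for _ in range(steps):
        x = rng.choice(xs, size=batch_size, replace=False)
        s = np.exp(rho)
        eps = rng.standard_normal(batch_size)
        z = a * x + s * eps  # reparametrized sample; all randomness is in eps
        # Single-sample per-example ELBO estimate, up to an additive constant:
        #   log p_theta(x | z) + log p(z) - log q_phi(z | x)
        #   = -0.5 * (x - theta * z)**2 - 0.5 * z**2 + 0.5 * eps**2 + rho
        d_dz = theta * (x - theta * z) - z          # derivative w.r.t. z
        grad_a = np.mean(d_dz * x)                  # dz/da = x
        grad_rho = np.mean(d_dz * eps * s) + 1.0    # dz/drho = s * eps, plus d(rho)/drho = 1
        grad_theta = np.mean((x - theta * z) * z)   # z is held fixed when differentiating in theta
        # Simultaneous update of phi and theta (no inner loop run to convergence).
        a += lr * grad_a
        rho += lr * grad_rho
        theta += lr * grad_theta
    return a, np.exp(rho), theta
```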
<subsection|Factor analysis model><label|sec:fa>
<subsubsection|Generative model>
The factor analysis (FA) model is the generative model defined as follows:
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|x><rsub|i>>|<cell|\<sim\>>|<cell|<with|font|cal|N><around*|(|<with|font-series|bold|W><with|font-series|bold|z><rsub|i>,<with|font-series|bold|\<Phi\>>|)>>>|<row|<cell|<with|font-series|bold|z><rsub|i>>|<cell|\<sim\>>|<cell|<with|font|cal|N><around*|(|0,<with|font-series|bold|I><rsub|L>|)>>>>>
</eqnarray*>
where <math|<with|font-series|bold|z><rsub|i>\<in\>\<bbb-R\><rsup|L>> is
the latent variable, <math|<with|font-series|bold|x><rsub|i>\<in\>\<bbb-R\><rsup|D>>
is the observed variable, <math|<with|font-series|bold|W>\<in\>\<bbb-R\><rsup|D\<times\>L>>
is the <with|font-shape|italic|factor loading> matrix and
<math|<with|font-series|bold|\<Phi\>>> is a diagonal covariance matrix. The
observed variable <math|<with|font-series|bold|x>> is \Pgenerated\Q by
linearly transforming the latent variable <math|<with|font-series|bold|z>>
and adding Gaussian noise with diagonal covariance. We have assumed that
<math|<with|font-series|bold|x>> has zero mean, since it is trivial to
de-mean a dataset.
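The generative process above can be sketched in a few lines of NumPy; the helper <python|sample_fa> and the parameter values in the test are hypothetical, chosen only for illustration. Integrating out <math|<with|font-series|bold|z>> gives the marginal <math|<with|font-series|bold|x>\<sim\><with|font|cal|N>(0,<with|font-series|bold|W><with|font-series|bold|W><rsup|T>+<with|font-series|bold|\<Phi\>>)>, so the sample covariance of <python|xs> should approach <math|<with|font-series|bold|W><with|font-series|bold|W><rsup|T>+<with|font-series|bold|\<Phi\>>>.

```python
import numpy as np


def sample_fa(W, phi_diag, n, seed=0):
    """Draw n samples from a zero-mean FA model (illustrative helper):
    z ~ N(0, I_L) and x | z ~ N(W z, diag(phi_diag))."""
    rng = np.random.default_rng(seed)
    D, L = W.shape
    zs = rng.standard_normal((n, L))                         # latent variables
    noise = rng.standard_normal((n, D)) * np.sqrt(phi_diag)  # diagonal Gaussian noise
    xs = zs @ W.T + noise                                    # x_i = W z_i + noise_i
    return xs, zs
```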
For simplicity, we fit a low-dimensional FA model with <math|L=2> and
<math|D=3> to a synthetic dataset generated by a ground-truth FA model with
<math|L=2> and <math|D=3>. Because 3-dimensional data is difficult to
visualize, we show the data projected onto the
<math|x<rsub|1>>-<math|x<rsub|2>>, <math|x<rsub|1>>-<math|x<rsub|3>> and
<math|x<rsub|2>>-<math|x<rsub|3>> planes (Figure <reference|fig:fa-data>).
The goal is to see whether AEVB can be used to successfully fit the FA
model.<float|float|t|<\big-figure|<image|01_factor_analysis/fa_data.pdf|0.8par|||>>
Synthetic dataset (<math|n=1000>) generated by a factor analysis model
with <math|L=2> and <math|D=3>.<label|fig:fa-data>
</big-figure>>
The FA model to be fitted can be defined as a PyTorch<\footnote>
<python|nn> is the short-hand for <python|torch.nn>; <python|F> is the
short-hand for <python|torch.nn.functional>; <python|Ind> is the short-hand
for <python|torch.distributions.Independent>; <python|Normal> is the
short-hand for <python|torch.distributions.Normal>.
</footnote> module, which conveniently allows for a learnable
<math|<with|font-series|bold|W>> and a learnable standard deviation vector
(whose elementwise square is the diagonal of
<math|<with|font-series|bold|\<Phi\>>>):
<with|font-base-size|8|<\python-code>
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Independent as Ind, Normal
class p_x_given_z_class(nn.Module):
\ \ \ \
\ \ \ \ def __init__(self):
\ \ \ \ \ \ \ \ super().__init__()
\ \ \ \ \ \ \ \ self.W = nn.Parameter(data=torch.randn(3, 2))
\ \ \ \ \ \ \ \ self.pre_sigma = nn.Parameter(data=torch.randn(3))
\ \ \ \ \ \ \ \
\ \ \ \ @property
\ \ \ \ def sigma(self):
\ \ \ \ \ \ \ \ # softplus keeps the standard deviations positive
\ \ \ \ \ \ \ \ return F.softplus(self.pre_sigma)
\ \ \ \ \ \ \ \
\ \ \ \ def forward(self, zs):
\ \ \ \ \ \ \ \ # zs shape: (batch size, 2)
\ \ \ \ \ \ \ \ mus = (self.W @ zs.T).T \ # mus shape: (batch size, 3)
\ \ \ \ \ \ \ \ # self.sigma shape: (3, ); Ind makes the 3 coordinates a single event
\ \ \ \ \ \ \ \ return Ind(Normal(mus, self.sigma), reinterpreted_batch_ndims=1)
</python-code>>
<subsubsection|Approximate posterior>
Running AEVB requires that we define a family of approximate posteriors
<math|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>.
Fortunately, for an FA model, analytic results are available (see Section 12.1.2 of
<cite|murphy2012machine>): one can show that the exact posterior
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>>
is a Gaussian whose mean is related to <math|<with|font-series|bold|x>> by
a linear transformation (we shall denote this matrix by
<math|<with|font-series|bold|V>>) and whose covariance matrix is full but
independent of <math|<with|font-series|bold|x>> (we shall denote this
matrix by <with|font-series|bold|<math|\<Sigma\>>>). Therefore, we can
simply pick such Gaussians as the parametrized family<\footnote>
This family contains the true posterior, so doing variational inference
for FA turns out to be doing exact inference, except that we don't need
to derive the complicated closed-form formulas for
<math|<with|font-series|bold|V>> and <math|<with|font-series|bold|\<Sigma\>>>.
</footnote> (<math|<with|font-series|bold|\<phi\>>=<around*|(|<with|font-series|bold|V>,<with|font-series|bold|\<Sigma\>>|)>>),
and define its PyTorch<\footnote>
<python|><python|MNormal> is the short-hand for
<python|torch.distributions.MultivariateNormal>. \
</footnote> module:
<with|font-base-size|8|<\python-code>
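import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal as MNormal
class q_z_given_x_class(nn.Module):
\ \ \ \
\ \ \ \ def __init__(self):
\ \ \ \ \ \ \ \ super().__init__()
\ \ \ \ \ \ \ \ self.V = nn.Parameter(data=torch.randn(2, 3))
\ \ \ \ \ \ \ \ # upper-triangular Cholesky factor U of the covariance (cov = U^T U),
\ \ \ \ \ \ \ \ # initialized so that cov is the identity; torch.linalg.cholesky
\ \ \ \ \ \ \ \ # replaces the deprecated torch.cholesky(..., upper=True)
\ \ \ \ \ \ \ \ self.cov_decomp = nn.Parameter(torch.linalg.cholesky(torch.eye(2)).T.contiguous())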
\ \ \ \ \ \ \
\ \ \ \ @property
\ \ \ \ def cov(self):
\ \ \ \ \ \ \ \ temp = torch.triu(self.cov_decomp)
\ \ \ \ \ \ \ \ return temp.T @ temp
\ \ \ \ \ \ \ \
\ \ \ \ def forward(self, xs):
\ \ \ \ \ \ \ \ # xs shape: (batch size, 3)
\ \ \ \ \ \ \ \ mus = (self.V @ xs.T).T
\ \ \ \ \ \ \ \ return MNormal(mus, self.cov)
</python-code>>
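where we have followed the standard practice of learning the
<with|font-shape|italic|Cholesky factor> of the covariance matrix rather
than the covariance matrix directly: for any upper-triangular matrix
<math|<with|font-series|bold|U>>, the product
<math|<with|font-series|bold|U><rsup|T><with|font-series|bold|U>> is
symmetric positive semi-definite (and positive definite when no diagonal
entry of <math|<with|font-series|bold|U>> is zero), so the entries of
<math|<with|font-series|bold|U>> can be left unconstrained, which makes them
more amenable to gradient-based optimization.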
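Although AEVB does not need them, the closed-form <math|<with|font-series|bold|V>> and <math|<with|font-series|bold|\<Sigma\>>> are handy for checking a learned <math|q<rsub|<with|font-series|bold|\<phi\>>>>. The following NumPy sketch uses the standard FA posterior formulas (cf. Section 12.1.2 of <cite|murphy2012machine>) specialized to zero-mean data; the function name and parameter values are hypothetical. The test cross-checks the result against the Gaussian-conditioning form <math|<with|font-series|bold|V>=<with|font-series|bold|W><rsup|T>(<with|font-series|bold|W><with|font-series|bold|W><rsup|T>+<with|font-series|bold|\<Phi\>>)<rsup|-1>> and recovers an upper Cholesky factor <math|<with|font-series|bold|U>> with <math|<with|font-series|bold|\<Sigma\>>=<with|font-series|bold|U><rsup|T><with|font-series|bold|U>>, the parametrization used by the <python|cov> property above.

```python
import numpy as np


def fa_exact_posterior(W, phi_diag):
    """Exact posterior p(z | x) = N(V x, Sigma) of a zero-mean FA model with
    z ~ N(0, I_L) and x | z ~ N(W z, diag(phi_diag))."""
    L = W.shape[1]
    Phi_inv = np.diag(1.0 / phi_diag)
    Sigma = np.linalg.inv(np.eye(L) + W.T @ Phi_inv @ W)  # posterior covariance, independent of x
    V = Sigma @ W.T @ Phi_inv                             # posterior mean is V @ x
    return V, Sigma
```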
<subsubsection|Estimator of per-example ELBO><label|sec:fa-aevb>
Recall that, in both the E-step and the M-step, the primary challenge is
that we need to compute gradients of per-example ELBOs: since the
expectation operators contained in per-example ELBOs are not assumed to be
tractable, we cannot evaluate these gradients exactly and must resort to
using estimators for these gradients. It turns out that we can first
construct an unbiased, low-variance estimator for each expectation rather
than for each gradient; then, as long as the source of randomness of this
estimator does <with|font-shape|italic|not> depend on the parameters (which
is already true for <math|<with|font-series|bold|\<theta\>>>), the gradient of this
unbiased estimator would be an unbiased estimator of the gradient. To
achieve this \Pindependence\Q, we apply the
<with|font-shape|italic|reparametrization trick> to the per-example ELBO as
follows:
<\eqnarray*>
<tformat|<table|<row|<cell|>|<cell|>|<cell|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>>>|<row|<cell|>|<cell|=>|<cell|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i>|)>+log
p<with|font-series|bold|><around*|(|<with|font-series|bold|z><rsub|i>|)>-log
q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>|]>>>|<row|<cell|>|<cell|=>|<cell|\<bbb-E\><rsub|<with|font-series|bold|z><rsub|i>\<sim\>q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i>|)>|]>-D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>\<parallel\>p<around*|(|<with|font-series|bold|z><rsub|i>|)>|)>>>|<row|<cell|>|<cell|=>|<cell|\<bbb-E\><rsub|<wide*|<with|font-series|bold|\<varepsilon\>><rsub|i>\<sim\><with|font|cal|N><around*|(|0,<with|font-series|bold|I><rsub|2>|)>|\<wide-underbrace\>><rsub|<text|No
longer involves <math|<with|font-series|bold|\<phi\>>!>>>><around*|[|log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i><rsup|s>|)>|]>-D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>\<parallel\>p<around*|(|<with|font-series|bold|z><rsub|i>|)>|)><space|1em><around*|(|<text|><with|font-series|bold|z><rsub|i><rsup|s>=<with|font-series|bold|V>
<with|font-series|bold|x><rsub|i>+cholesky<around*|(|<with|font-series|bold|\<Sigma\>>|)><with|font-series|bold|\<varepsilon\>><rsub|i>|)>,>>>>
</eqnarray*>
where the KL term can be evaluated in closed form because both
distributions are Gaussian (which is the case for FA), and the source of
randomness in the remaining expectation indeed no longer depends on
<math|<with|font-series|bold|\<phi\>>>. We have used
<math|<with|font-series|bold|z><rsub|i><rsup|s>> to denote the
<with|font-shape|italic|reparametrized sample>. This reparametrized
expression allows us to instantiate the following unbiased, low-variance
estimator of the per-example ELBO:
<\equation*>
<wide|ELBO|^><around*|(|<with|font-series|bold|x><rsub|i>,<with|font-series|bold|\<theta\>>,<with|font-series|bold|\<phi\>>|)>=log
p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i><rsup|s>|)>-D<rsub|\<bbb-K\>\<bbb-L\>><around*|(|q<rsub|<with|font-series|bold|\<phi\>>><around*|(|<with|font-series|bold|z><rsub|i>\<mid\><with|font-series|bold|x><rsub|i>|)>\<parallel\>p<around*|(|<with|font-series|bold|z><rsub|i>|)>|)><space|1em><around*|(|<with|font-series|bold|z><rsub|i><rsup|s>=<with|font-series|bold|V>
<with|font-series|bold|x><rsub|i>+cholesky<around*|(|<with|font-series|bold|\<Sigma\>>|)><with|font-series|bold|\<varepsilon\>><rsub|i>|)>,
</equation*>
where <math|log p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i><rsup|s>|)>>
intuitively measures how well the generative and inference components
collaboratively perform <with|font-shape|italic|reconstruction> or
<with|font-shape|italic|auto-encoding><\footnote>
Hence the name <with|font-shape|italic|Auto-Encoding> Variational Bayes.
If the expectation within the KL term is also sampled (which yields a
higher-variance estimator), the algorithm is instead called
<with|font-shape|italic|Stochastic Gradient> Variational Bayes (SGVB).
</footnote>: encoding <math|<with|font-series|bold|x><rsub|i>>
probabilistically into <math|<with|font-series|bold|z><rsub|i><rsup|s>>,
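and then deterministically decoding
<math|<with|font-series|bold|z><rsub|i><rsup|s>> into the parameters of
<math|p<rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|x>\<mid\><with|font-series|bold|z><rsub|i><rsup|s>|)>>,
under which the original <math|<with|font-series|bold|x><rsub|i>> should
have high likelihood. In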
PyTorch<\footnote>
<python|kl_divergence> is the short-hand for
<python|torch.distributions.kl.kl_divergence>. Also, note that we used
<python|rsample> instead of <python|sample>. This is
<with|font-shape|italic|crucial>: otherwise we would not be able to
differentiate through <math|<with|font-series|bold|z><rsub|i><rsup|s>>.
</footnote>, we can implement this estimator and compute its gradient with
respect to <math|<with|font-series|bold|\<theta\>>> and
<math|<with|font-series|bold|\<phi\>>> as follows:
<with|font-base-size|8|<\python-code>
import torch.nn as nn
from torch.distributions.kl import kl_divergence
class AEVB(nn.Module):
\ \ \ \ # ...
\ \ \ \ def step(self, xs):
\ \ \ \ \ \ \ \ # xs shape: (batch size, 3)
\ \ \ \ \ \ \ \ posterior_over_zs = self.q_z_given_x(xs)
\ \ \ \ \ \ \ \ kl = kl_divergence(posterior_over_zs, self.p_z) \ # closed form for Gaussians
\ \ \ \ \ \ \ \ zs = posterior_over_zs.rsample() \ # reparametrized samples
\ \ \ \ \ \ \ \ rec = self.p_x_given_z(zs).log_prob(xs) \ # reconstruction
\ \ \ \ \ \ \ \ per_example_elbos = rec - kl \ # estimates of per-example ELBOs
\ \ \ \ \ \ \ \ mini_batch_elbo = per_example_elbos.mean()
\ \ \ \ \ \ \ \ loss = - mini_batch_elbo \ # optimizers minimize, so negate ELBO
\ \ \ \ \ \ \ \ self.optimizer.zero_grad()
\ \ \ \ \ \ \ \ loss.backward()
\ \ \ \ \ \ \ \ self.optimizer.step()
\ \ \ \ # ...
</python-code>>
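The two ingredients of this estimator, a reparametrized sample and a closed-form KL term, can be checked outside PyTorch with a minimal NumPy sketch. It assumes a standard normal prior, as in the FA model, and uses the standard formula <math|D<rsub|\<bbb-K\>\<bbb-L\>>(<with|font|cal|N>(<with|font-series|bold|\<mu\>>,<with|font-series|bold|\<Sigma\>>)\<parallel\><with|font|cal|N>(0,<with|font-series|bold|I>))=<frac|1|2>(tr <with|font-series|bold|\<Sigma\>>+<with|font-series|bold|\<mu\>><rsup|T><with|font-series|bold|\<mu\>>-L-log det <with|font-series|bold|\<Sigma\>>)>; the helper names are hypothetical. A Monte Carlo average of <math|log q-log p> over reparametrized samples should agree with the closed form.

```python
import numpy as np


def kl_gauss_to_std_normal(mu, Sigma):
    """Closed-form KL(N(mu, Sigma) || N(0, I))."""
    L = mu.shape[-1]
    _, logdet = np.linalg.slogdet(Sigma)
    return 0.5 * (np.trace(Sigma) + mu @ mu - L - logdet)


def log_normal(z, mu, Sigma):
    """Log density of N(mu, Sigma) evaluated at each row of z."""
    L = mu.shape[-1]
    diff = z - mu
    sol = np.linalg.solve(Sigma, diff.T).T
    _, logdet = np.linalg.slogdet(Sigma)
    return -0.5 * (np.sum(diff * sol, axis=1) + logdet + L * np.log(2 * np.pi))


def reparam_samples(mu, Sigma, n, rng):
    """z = mu + chol(Sigma) eps with eps ~ N(0, I): the randomness does not depend on (mu, Sigma)."""
    chol = np.linalg.cholesky(Sigma)  # lower-triangular, Sigma = chol @ chol.T
    eps = rng.standard_normal((n, mu.shape[-1]))
    return mu + eps @ chol.T
```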
<subsubsection|Results><label|sec:fa-results>
<with|font-series|bold|Experiment 1.><float|float|t|<\big-figure|<image|01_factor_analysis/fa_learning_curve.pdf|1par|||>>
Test set performance of the FA model across training. The red and green
curves show that the estimated ELBO and the evidence, respectively, improve
towards the evidence of the true model (black dotted line), and that ELBO
is indeed a lower bound to the evidence. As ELBO improves, the generated
data (orange points) gradually matches the test data (blue points) in
distribution.<label|fig:fa-learning>
</big-figure>> In this experiment, we empirically assess the convergence of
AEVB on our simple FA model. We run gradient ascent on estimated mini-batch
ELBO<\footnote>
This is simply an average of estimators of per-example ELBOs, as shown in
the code snippet.
</footnote>, and after every gradient step we measure the performance of AEVB
on two metrics: the estimated ELBO on the entire test set (<math|n=1000>)
and the exact evidence on the entire test set. For training
hyperparameters, we used a mini-batch size of 32 and Adam
<cite|kingma2014adam> with a learning rate of 1e-2.
It is worth noting that, for arbitrary latent variable models, the evidence
is generally intractable since it involves accurately evaluating the
integral in Equation <reference|eq:mle>. While Monte Carlo integration (by
first obtaining samples from <math|p<around*|(|<with|font-series|bold|z><rsub|i>|)>>
and then averaging <math|p<around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i>|)>>)
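can work, its variance is high when the latent space is high-dimensional,
and bringing the variance down requires a prohibitively large number of
samples. Fortunately, for an FA model the evidence can be expressed
analytically, since
<math|<with|font-series|bold|x><rsub|i>\<sim\><with|font|cal|N><around*|(|0,<with|font-series|bold|W><with|font-series|bold|W><rsup|T>+<with|font-series|bold|\<Phi\>>|)>>: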
<with|font-base-size|8|<\python-code>
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal as MNormal
class AEVB(nn.Module):
\ \ \ \ # ...
\ \ \ \ def compute_evidence(self, xs):
\ \ \ \ \ \ \ \ # xs shape: (batch size, 3)
\ \ \ \ \ \ \ \ W = self.p_x_given_z.W
\ \ \ \ \ \ \ \ Phi = torch.diag(self.p_x_given_z.sigma ** 2) \ # diagonal noise covariance
\ \ \ \ \ \ \ \ # marginal of x under FA is N(0, W W^T + Phi): low-rank plus diagonal
\ \ \ \ \ \ \ \ p_x = MNormal(torch.zeros(3), Phi + W @ W.T)
\ \ \ \ \ \ \ \ return float(p_x.log_prob(xs).mean()) \ # average per-example log evidence
\ \ \ \ # ...
</python-code>>
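To illustrate the trade-off, the following NumPy sketch compares the closed-form FA evidence with the naive Monte Carlo estimate described above (sampling <math|<with|font-series|bold|z>> from the prior and averaging <math|p(<with|font-series|bold|x>\<mid\><with|font-series|bold|z>)>, computed in log space for numerical stability). With only <math|L=2> latent dimensions the two agree closely given enough samples; the helper names and parameter values are illustrative assumptions.

```python
import numpy as np


def log_mvn(x, cov):
    """Log density of N(0, cov) at a single point x."""
    D = x.shape[-1]
    _, logdet = np.linalg.slogdet(cov)
    quad = x @ np.linalg.solve(cov, x)
    return -0.5 * (quad + logdet + D * np.log(2 * np.pi))


def fa_log_evidence_exact(x, W, phi_diag):
    # Marginal of x under the zero-mean FA model: N(0, W W^T + Phi).
    return log_mvn(x, W @ W.T + np.diag(phi_diag))


def fa_log_evidence_mc(x, W, phi_diag, n_samples, rng):
    L = W.shape[1]
    zs = rng.standard_normal((n_samples, L))  # z_s ~ p(z)
    mus = zs @ W.T                            # means of p(x | z_s)
    # log p(x | z_s) for a Gaussian with diagonal covariance diag(phi_diag)
    log_px_given_z = -0.5 * np.sum((x - mus) ** 2 / phi_diag + np.log(2 * np.pi * phi_diag), axis=1)
    m = log_px_given_z.max()
    return m + np.log(np.mean(np.exp(log_px_given_z - m)))  # log of the sample mean of p(x | z_s)
```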
During training, we also generate data from the learned FA model to check
whether they match the test data in distribution. This check is similar to
a posterior predictive check in Bayesian model fitting, which allows us to
see whether the model of choice is appropriate (e.g., does it underfit?).
Finally, we note that the learned FA model will not, in general, recover
the parameters of the true FA model, since the matrix
<math|<with|font-series|bold|W>> is only unique up to right-multiplication
by a <math|2\<times\>2> orthogonal matrix (e.g., a rotation).
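This non-identifiability is easy to verify numerically: for any rotation <math|<with|font-series|bold|R>>, <math|(<with|font-series|bold|W><with|font-series|bold|R>)(<with|font-series|bold|W><with|font-series|bold|R>)<rsup|T>=<with|font-series|bold|W><with|font-series|bold|W><rsup|T>>, so both loading matrices induce the same marginal distribution of <math|<with|font-series|bold|x>>. A minimal NumPy sketch, with an arbitrary illustrative <math|<with|font-series|bold|W>> and angle:

```python
import numpy as np


def rotation(angle):
    """2x2 rotation matrix (an orthogonal matrix with determinant 1)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s], [s, c]])


W = np.array([[1.0, 0.0], [0.5, 1.0], [-1.0, 0.5]])
W_rot = W @ rotation(0.7)
# Both loading matrices give the same marginal covariance W W^T + Phi for any Phi,
# so no amount of data can distinguish between them.
same_marginal = np.allclose(W @ W.T, W_rot @ W_rot.T)
```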
Figure <reference|fig:fa-learning> shows the results of this experiment. We
see that the estimated ELBO and the evidence on the test set improve over
time, eventually coming very close to the evidence of the true model. In
particular, ELBO is indeed a lower bound to the evidence. Predictive checks
at different stages of training show that, as ELBO improves, the generated
data becomes closer to, and eventually indistinguishable from, the test
data. These results suggest that AEVB was successful at fitting the FA
model. Empirically, we also find less expressive approximate posteriors
(e.g., with a diagonal covariance matrix, or with a diagonal covariance
matrix with fixed<\footnote>
This works well only when the fixed values are small.
</footnote> entries) to work reasonably well.
<with|font-series|bold|Experiment 2.> <float|float|t|<\big-figure|<image|01_factor_analysis/fa_learning_curve_em.pdf|1par|||>>
Test set performance of the FA model across training when alternating
between periods of only updating inference parameters
<math|<with|font-series|bold|\<phi\>>> (gray regions) and periods of only
updating generative parameters <math|<with|font-series|bold|\<theta\>>>
(white regions). In gray regions, ELBO becomes tight; in white regions,
both ELBO and evidence improve, but ELBO is no longer tight.
<label|fig:fa-em>
</big-figure>>In this experiment, we confirm that viewing AEVB through the
lens of EM is reasonable. In particular, we verify the following: if we keep
updating the inference parameters without updating the generative
parameters (approximate E-step), then the ELBO becomes a tight lower bound
on the evidence; if we then switch to updating only the generative
parameters (approximate M-step), then both the ELBO and the evidence
improve, but the ELBO is no longer a tight lower bound. Hyperparameter
values are unchanged from Experiment 1.
The experiment is designed as follows. For two rounds, we alternate between
1000 gradient steps that update only the inference parameters and 1000
gradient steps that update only the generative parameters, for 4000
gradient steps in total. As in Experiment 1, we track the estimated ELBO and
the evidence on the entire test set after each gradient step.
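The alternating schedule can be sketched as follows. This is a toy stand-in,
not the code used for the experiment: <python|ToyFA> and
<python|train_alternating> are hypothetical names, the encoder is a single
linear layer producing a diagonal Gaussian, and the usage below trains on
synthetic data with 100-step blocks instead of 1000 to keep it fast. The
main design choice is to keep one optimizer per parameter group and step
only the active one.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal, kl_divergence


class ToyFA(nn.Module):
    """Hypothetical FA model with an amortized diagonal-Gaussian posterior (illustrative only)."""

    def __init__(self, x_dim=3, z_dim=2):
        super().__init__()
        # Generative parameters theta: loading matrix and per-dimension log noise scale.
        self.W = nn.Parameter(torch.randn(x_dim, z_dim))
        self.log_sigma = nn.Parameter(torch.zeros(x_dim))
        # Inference parameters phi: linear encoder outputting posterior means and log-stds.
        self.encoder = nn.Linear(x_dim, 2 * z_dim)

    def elbo(self, xs):
        mean, log_std = self.encoder(xs).chunk(2, dim=-1)
        q = Independent(Normal(mean, log_std.exp()), 1)
        z = q.rsample()  # one reparameterized sample per data point
        p_x_given_z = Independent(Normal(z @ self.W.T, self.log_sigma.exp()), 1)
        p_z = Independent(Normal(torch.zeros_like(mean), torch.ones_like(mean)), 1)
        return (p_x_given_z.log_prob(xs) - kl_divergence(q, p_z)).mean()


def train_alternating(model, xs, period=1000, rounds=2, lr=1e-2):
    """Alternate `period` phi-only steps (approx. E-step) with `period` theta-only steps (approx. M-step)."""
    opt_phi = torch.optim.Adam(model.encoder.parameters(), lr=lr)
    opt_theta = torch.optim.Adam([model.W, model.log_sigma], lr=lr)
    elbos = []
    for step in range(2 * rounds * period):
        # Even blocks update phi only; odd blocks update theta only.
        opt = opt_phi if (step // period) % 2 == 0 else opt_theta
        model.zero_grad()
        loss = -model.elbo(xs)
        loss.backward()
        opt.step()  # only the active group's parameters change
        elbos.append(-loss.item())
    return elbos


torch.manual_seed(0)
# Synthetic data from a 2-factor linear model plus noise.
xs = torch.randn(256, 2) @ torch.randn(2, 3) + 0.3 * torch.randn(256, 3)
model = ToyFA()
history = train_alternating(model, xs, period=100, rounds=2)
print(len(history), history[0], history[-1])
```

Tracking the evidence alongside the ELBO, as in the experiment, would
additionally require evaluating the closed-form marginal likelihood after
each step.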
Figure <reference|fig:fa-em> shows the results of this experiment. Gray
regions represent periods during which we only updated the inference
parameters, and white regions represent periods during which we only
updated the generative parameters. In gray regions, the evidence stays
fixed while the ELBO gradually becomes a tight lower bound. In hindsight,
this is unsurprising: the inference parameters do not participate in the
evidence computation, so maximizing the ELBO with respect to them amounts to
shrinking the gap between the ELBO and the evidence. In white regions, both
the ELBO and the evidence improve, but the ELBO is no longer a tight lower
bound.
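The behavior in gray regions follows from the identity that the ELBO equals
the log-evidence minus the KL divergence from the approximate posterior to
the true posterior. For FA, every term has a closed form, so the identity
can be checked directly. The minimal sketch below uses hypothetical
parameter values and a single observation: when <math|q> equals the exact
Gaussian posterior, the ELBO matches the evidence up to round-off; another
choice of <math|q> (here, the prior) gives a strictly smaller value, while
the evidence itself does not depend on <math|q> at all.

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

f64 = torch.float64  # double precision so the equality check is tight

# Hypothetical FA parameters (3-dimensional x, 2-dimensional z) and one observation.
W = torch.tensor([[1.0, 0.5], [-0.3, 0.8], [0.2, -1.0]], dtype=f64)
sigma = torch.tensor([0.5, 0.4, 0.3], dtype=f64)
x = torch.tensor([0.7, -0.2, 0.4], dtype=f64)
Phi = torch.diag(sigma ** 2)
Phi_inv = torch.diag(sigma ** -2)
prior = MultivariateNormal(torch.zeros(2, dtype=f64), torch.eye(2, dtype=f64))

# Evidence: log N(x; 0, Phi + W W^T). It depends only on the generative parameters.
evidence = MultivariateNormal(torch.zeros(3, dtype=f64), Phi + W @ W.T).log_prob(x)


def elbo(m, S):
    """Closed-form ELBO for a Gaussian approximate posterior q(z) = N(m, S)."""
    # E_q[log N(x; W z, Phi)] = log N(x; W m, Phi) - 0.5 * tr(Phi^{-1} W S W^T)
    expected_ll = (MultivariateNormal(W @ m, Phi).log_prob(x)
                   - 0.5 * torch.trace(Phi_inv @ W @ S @ W.T))
    return expected_ll - kl_divergence(MultivariateNormal(m, S), prior)


# Exact posterior: S_post = (I + W^T Phi^{-1} W)^{-1}, m_post = S_post W^T Phi^{-1} x.
S_post = torch.linalg.inv(torch.eye(2, dtype=f64) + W.T @ Phi_inv @ W)
S_post = 0.5 * (S_post + S_post.T)  # symmetrize against round-off
m_post = S_post @ W.T @ Phi_inv @ x

print(float(evidence), float(elbo(m_post, S_post)))  # equal up to round-off: the bound is tight
print(float(elbo(torch.zeros(2, dtype=f64), torch.eye(2, dtype=f64))))  # q = prior: strictly lower
```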
<subsection|Variational autoencoder><label|sec:vae>
<subsubsection|Generative model>
<float|float|t|<\big-figure|<image|02_vae/vae_mnist_orgs.pdf|0.8par|||>>
Original MNIST images.<label|fig:vae-mnist-org>
</big-figure>>The generative part of the VAE <cite|kingma2013auto> for
normalized<\footnote>
Original MNIST pixel values are integers from 0 to 255. We follow the
common practice of adding uniform noise on <math|<around*|[|0,1|)>> to each
pixel and then dividing by 256, which yields continuous values in
<math|<around*|[|0,1|)>>.
</footnote> MNIST images (Figure <reference|fig:vae-mnist-org>) is defined
as follows:
<\eqnarray*>
<tformat|<table|<row|<cell|<with|font-series|bold|x><rsub|i>>|<cell|\<sim\>>|<cell|<text|Product-Of-Continuous-Bernoullis><around*|(|<with|font-series|bold|\<lambda\>><rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>|)>|)>>>|<row|<cell|<with|font-series|bold|z><rsub|i>>|<cell|\<sim\>>|<cell|<with|font|cal|N><around*|(|0,<with|font-series|bold|I><rsub|L>|)>>>>>
</eqnarray*>
where <math|<with|font-series|bold|z><rsub|i>\<in\>\<bbb-R\><rsup|L>> is
the latent variable and <math|<with|font-series|bold|x><rsub|i>\<in\>\<bbb-R\><rsup|D>>
is the observed variable. Here, <math|D=28\<times\>28=784>, since each image
is <math|28> pixels high and <math|28> pixels wide.
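Each factor in this likelihood is a continuous Bernoulli density
<cite|loaiza2019continuous>, which, unlike a Bernoulli probability mass
function, includes a normalizing constant:
<math|p<around*|(|x\<mid\>\<lambda\>|)>=C<around*|(|\<lambda\>|)>\<lambda\><rsup|x><around*|(|1-\<lambda\>|)><rsup|1-x>>
with <math|C<around*|(|\<lambda\>|)>=2*<text|artanh><around*|(|1-2\<lambda\>|)>/<around*|(|1-2\<lambda\>|)>>
for <math|\<lambda\>\<neq\>1/2> and <math|C<around*|(|1/2|)>=2>. The
minimal sketch below, which is not the paper's code, evaluates the
log-density of a product of <math|D=784> such factors by hand and compares
it with PyTorch's <python|ContinuousBernoulli>. The helper names are
hypothetical, and the random inputs merely stand in for normalized images
and decoder outputs.

```python
import math

import torch
from torch.distributions import ContinuousBernoulli, Independent


def cb_log_norm_const(lam):
    """log C(lambda), with C(lambda) = 2 artanh(1 - 2 lambda) / (1 - 2 lambda) and C(1/2) = 2."""
    t = 1 - 2 * lam
    near_half = t.abs() < 1e-3
    safe_t = torch.where(near_half, torch.full_like(t, 0.5), t)  # avoid 0/0 at lambda = 1/2
    exact = torch.log(2 * torch.atanh(safe_t) / safe_t)
    taylor = math.log(2.0) + torch.log1p(t ** 2 / 3)  # 2 artanh(t) / t = 2 (1 + t^2/3 + ...)
    return torch.where(near_half, taylor, exact)


def product_cb_log_prob(x, lam):
    """log p(x | lambda) for D independent continuous Bernoullis (sum over the last axis)."""
    per_pixel = cb_log_norm_const(lam) + x * torch.log(lam) + (1 - x) * torch.log1p(-lam)
    return per_pixel.sum(dim=-1)


torch.manual_seed(0)
D = 28 * 28
x = torch.rand(4, D)                 # stand-in for 4 normalized images in [0, 1)
lam = 0.05 + 0.9 * torch.rand(4, D)  # stand-in for decoder outputs lambda_theta(z)

ours = product_cb_log_prob(x, lam)
reference = Independent(ContinuousBernoulli(probs=lam), 1).log_prob(x)
print(torch.allclose(ours, reference, rtol=1e-4, atol=1e-2))
```

Wrapping the per-pixel distribution in <python|Independent> with one
reinterpreted dimension is what turns the <math|D> factors into a single
product density whose log-probability is summed over pixels.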
Product-Of-Continuous-Bernoullis is a product of <math|D> independent
continuous Bernoulli distributions <cite|loaiza2019continuous>, each with
support <math|<around*|[|0,1|]>> and parameter
<math|\<lambda\>\<in\><around*|[|0,1|]>>, giving the PDF
<\equation*>
p<around*|(|<with|font-series|bold|x><rsub|i>\<mid\><with|font-series|bold|z><rsub|i>|)>=<big|prod><rsub|j=1><rsup|D><text|Continuous-Bernoulli><around*|(|x<rsub|i
j>\<mid\>\<lambda\><rsub|i j >|)><text|<space|1em>where<space|1em>>\<lambda\><rsub|i
j >=<with|font-series|bold|\<lambda\>><rsub|<with|font-series|bold|\<theta\>>><around*|(|<with|font-series|bold|z><rsub|i>|)><rsub|j>
</equation*>
where <math|<with|font-series|bold|\<lambda\>><rsub|<with|font-series|bold|\<theta\>>>:\<bbb-R\><rsup|L>\<rightarrow\><around*|[|0,1|]><rsup|D>>
is a neural network that maps latent vectors to parameter vectors of the
product of independent continuous Bernoullis. We can define this generative
model in PyTorch<\footnote>
<python|CB> is shorthand for
<python|torch.distributions.ContinuousBernoulli>.
</footnote> (following the network architecture in the official TensorFlow
code<\footnote>
Code: <slink|https://github.com/cunningham-lab/cb_and_cc/blob/master/cb/utils.py>
</footnote> for <cite|loaiza2019continuous>):
<\with|font-base-size|8>
<\python-code>
class p_x_given_z_class(nn.Module):