-
Notifications
You must be signed in to change notification settings - Fork 4
/
LinRegAux.hs
246 lines (196 loc) · 8.64 KB
/
LinRegAux.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# OPTIONS_GHC -fno-warn-unused-do-bind #-}
{-# OPTIONS_GHC -fno-warn-missing-methods #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module LinRegAux (
denv
, diagGamma
, diagNormal
, diagNormals
, diag
, barDiag
, MCMCAnal (..)
, wikiDiag
) where
import System.IO.Unsafe
import Graphics.Rendering.Chart.Backend.Diagrams
import Graphics.Rendering.Chart
import Numeric.SpecFunctions
import Text.Printf
import Data.Colour
import Control.Lens hiding ( (#) )
import Diagrams.Backend.Cairo.CmdLine
import Diagrams.Prelude hiding ( sample, render )
import Data.Default.Class
-- | Shared Chart-to-diagrams rendering environment (500 x 500, built with
-- 'vectorAlignmentFns'), used by every @diag*@ function in this module.
--
-- It is created once via 'unsafePerformIO'. The NOINLINE pragma is required:
-- without it GHC may inline this CAF at its use sites, which would re-run the
-- 'defaultEnv' IO action more than once.
denv :: DEnv
denv = unsafePerformIO $ defaultEnv vectorAlignmentFns 500 500
{-# NOINLINE denv #-}
-- | Log density of the Gamma distribution with shape @alpha@ and rate @beta@,
-- evaluated at @x@:
--
-- > alpha * log beta + (alpha - 1) * log x - beta * x - logGamma alpha
--
-- Edge cases:
--
-- * For @x < 0@ (outside the support) the density is zero, so the result is
--   negative infinity. The raw formula would give NaN there.
-- * The @(alpha - 1) * log x@ term is taken as 0 when @alpha == 1@. At
--   @x == 0@ this gives @log beta@ (the exponential density at the origin)
--   instead of @0 * (-Infinity) = NaN@.
logGammaPdf :: Double -> Double -> Double -> Double
logGammaPdf alpha beta x
  | x < 0     = negativeInfinity
  | otherwise = alpha * log beta + shapeTerm - beta * x - logGamma alpha
  where
    negativeInfinity = -1 / 0
    -- Skip log x entirely when its coefficient is zero, to avoid 0 * -Inf.
    shapeTerm
      | alpha == 1 = 0
      | otherwise  = (alpha - 1) * log x
-- | Line chart overlaying two Gamma densities on the grid @[0, 0.05 .. 20]@:
-- the prior (shape @a@, rate @b@) in blue and the posterior (shape @a'@,
-- rate @b'@) in red.
--
-- NOTE(review): the chart title hard-codes "(10 Observations)" regardless of
-- the arguments.
gammaPlot :: Double -> Double -> Double -> Double -> Graphics.Rendering.Chart.Renderable ()
gammaPlot a b a' b' = toRenderable chartLayout
  where
    -- Gamma density for the given shape and rate.
    density :: Double -> Double -> Double -> Double
    density shape rate x = exp (logGammaPdf shape rate x)

    -- One coloured, titled density curve over the fixed x grid.
    curve colour label pdf =
        plot_lines_values .~ [[ (x, pdf x) | x <- [0, 0.05 .. 20] ]]
      $ plot_lines_style . line_color .~ opaque colour
      $ plot_lines_title .~ label
      $ def

    priorCurve = curve blue
      ("prior shape = " ++ printf "%3.3f" a ++ " rate = " ++ printf "%3.3f" b)
      (density a b)

    posteriorCurve = curve red
      ("posterior shape = " ++ printf "%3.3f" a' ++ " rate = " ++ printf "%3.3f" b')
      (density a' b')

    chartLayout = layout_title .~ "Gamma Prior and Posterior (10 Observations)"
                $ layout_plots .~ [toPlot priorCurve, toPlot posteriorCurve]
                $ def
-- | 'gammaPlot' rendered at 500 x 500 into a Cairo diagram via the shared
-- 'denv'. The second component returned by 'runBackend' is discarded.
diagGamma :: Double -> Double -> Double -> Double -> QDiagram Cairo R2 Any
diagGamma a b a' b' = diagram
  where
    (diagram, _) = runBackend denv (render (gammaPlot a b a' b') (500, 500))
-- | Density of the normal distribution with mean @mu@ and standard deviation
-- @sigma@, evaluated at @x@.
normalPdf :: Double -> Double -> Double -> Double
normalPdf mu sigma x = scale * exp quadratic
  where
    -- 1 / (sigma * sqrt (2 pi))
    scale     = recip (sigma * sqrt (2.0 * pi))
    -- -(x - mu)^2 / (2 sigma^2)
    quadratic = negate ((x - mu) ** 2 / (2 * sigma ** 2))
-- | Line chart of three normal densities on the grid @[-2, -1.95 .. 5]@:
-- the prior (mean @a@, second parameter @b@) in blue, and two posteriors,
-- (@a'@, @b'@) in red and (@a''@, @b''@) in green.
--
-- NOTE(review): the legend labels the second parameter "var", but
-- 'normalPdf' treats it as a standard deviation. Confirm what callers pass.
normalPlot :: Double -> Double ->
              Double -> Double ->
              Double -> Double ->
              Graphics.Rendering.Chart.Renderable ()
normalPlot a b a' b' a'' b'' = toRenderable chartLayout
  where
    -- One coloured density curve; the legend is "<prefix><mean> var = <s>".
    curve colour prefix m s =
        plot_lines_values .~ [[ (x, normalPdf m s x) | x <- [-2, -1.95 .. 5] ]]
      $ plot_lines_style . line_color .~ opaque colour
      $ plot_lines_title .~ prefix ++ printf "%3.3f" m ++ " var = " ++ printf "%3.3f" s
      $ def

    chartLayout = layout_title .~ "Normal Prior and Posterior (10 & 100 Observations)"
                $ layout_plots .~ map toPlot
                    [ curve blue  "prior mean = " a   b
                    , curve red   "post mean = "  a'  b'
                    , curve green "post mean = "  a'' b''
                    ]
                $ def
-- | 'normalPlot' rendered at 500 x 500 into a Cairo diagram via the shared
-- 'denv'.
diagNormal :: Double -> Double ->
              Double -> Double ->
              Double -> Double ->
              QDiagram Cairo R2 Any
diagNormal a b a' b' a'' b'' = diagram
  where
    plotR        = normalPlot a b a' b' a'' b''
    (diagram, _) = runBackend denv (render plotR (500, 500))
-- | Scatter plot titled "Price History" of the @(x, y)@ points @prices@, drawn
-- in colour @c@. The same point series is attached to both the left and right
-- y axes. Grid lines are hidden on all three axes.
chart :: Colour Double -> [(Double, Double)] -> Graphics.Rendering.Chart.Renderable ()
chart c prices = toRenderable lrLayout
  where
    points = plot_points_title .~ "price 1"
           $ plot_points_values .~ prices
           $ plot_points_style . point_color .~ opaque c
           $ def

    lrLayout = layoutlr_title .~ "Price History"
             $ layoutlr_grid_last .~ False
             $ layoutlr_left_axis  . laxis_override .~ axisGridHide
             $ layoutlr_right_axis . laxis_override .~ axisGridHide
             $ layoutlr_x_axis     . laxis_override .~ axisGridHide
             $ layoutlr_plots .~ [Left (toPlot points), Right (toPlot points)]
             $ def
-- | 'chart' rendered at 500 x 500 into a Cairo diagram via the shared 'denv'.
diag :: Colour Double -> [(Double, Double)] -> QDiagram Cairo R2 Any
diag c prices = fst . runBackend denv $ render (chart c prices) (500, 500)
-- | Line chart with one normal-density curve per @(mean, b, colour, label)@
-- entry. Each curve is evaluated at 'normalPdf' @mean b@ over
-- @[mean - 5b, mean - 5b + b/100 .. mean + 5b]@ (about 1000 steps).
--
-- A curve whose @b@ is exactly 0 gets no points. Previously the step size was
-- then 0, which made the enumeration infinite and hung rendering.
--
-- NOTE(review): the legend labels @b@ as "var", but 'normalPdf' treats it as a
-- standard deviation. Confirm what callers pass.
normalPlots :: [(Double, Double, Colour Double, String)] ->
               Graphics.Rendering.Chart.Renderable ()
normalPlots specs = toRenderable layout
  where
    layout = layout_title .~ "Normal Prior and Posterior"
           $ layout_plots .~ map (toPlot . curve) specs
           $ def

    -- Evaluate each curve's own density directly. The old version computed
    -- every curve's pdf at each point and then picked one with (!!).
    curve (a, b, c, l) =
        plot_lines_values .~ [[ (x, normalPdf a b x) | x <- grid a b ]]
      $ plot_lines_style . line_color .~ opaque c
      $ plot_lines_title .~ l ++ " mean = " ++ printf "%3.3f" a ++
                            " var = " ++ printf "%3.3f" b
      $ def

    -- Sample points spanning mean +/- 5b. A zero step would never terminate.
    grid a b
      | b == 0    = []
      | otherwise = [lower, lower + step .. upper]
      where
        lower = a - 5 * b
        upper = a + 5 * b
        step  = 10 * b / 1000.0
-- | 'normalPlots' rendered at 500 x 500 into a Cairo diagram via the shared
-- 'denv'.
diagNormals :: [(Double, Double, Colour Double, String)] ->
               QDiagram Cairo R2 Any
diagNormals specs = fst (runBackend denv program)
  where
    program = render (normalPlots specs) (500, 500)
-- | Selects which histogram(s) 'barChart' / 'barDiag' draw.
data MCMCAnal
  = MCMC      -- ^ only the MCMC histogram
  | Anal      -- ^ only the analytic histogram
  | MCMCAnal  -- ^ both, overlaid
-- | Clustered bar chart of frequencies. @bvs@ is drawn as the "MCMC" series in
-- translucent blue and @bvs'@ as the "Analytic" series in translucent red.
-- The 'MCMCAnal' mode selects which series appear and sets the title.
-- Only the second component of each pair is plotted. The x-axis labels come
-- from the first components of @bvs@ alone (formatted @%3.1f@).
barChart :: MCMCAnal ->
            [(Double, Double)] ->
            [(Double, Double)] ->
            Graphics.Rendering.Chart.Renderable ()
barChart pt bvs bvs' = toRenderable layout
  where
    (title, selected) = case pt of
      MCMC     -> ("Posterior via MCMC",                    [mcmcBars])
      Anal     -> ("Analytic Posterior",                    [analyticBars])
      MCMCAnal -> ("MCMC and Analytic Posteriors Overlaid", [mcmcBars, analyticBars])

    layout = layout_title .~ title
           $ layout_x_axis . laxis_generate .~ autoIndexAxis [ printf "%3.1f" x | (x, _) <- bvs ]
           $ layout_y_axis . laxis_title .~ "Frequency"
           $ layout_plots .~ map plotBars selected
           $ def

    -- A single-series bar plot, 25% opaque, with no border.
    mkBars label colour values =
        plot_bars_titles .~ [label]
      $ plot_bars_values .~ addIndexes [ [v] | (_, v) <- values ]
      $ plot_bars_style .~ BarsClustered
      $ plot_bars_item_styles .~ [(solidFillStyle (colour `withOpacity` 0.25), Nothing)]
      $ def

    mcmcBars     = mkBars "MCMC" blue bvs
    analyticBars = mkBars "Analytic" red bvs'
-- | 'barChart' rendered at 500 x 500 into a Cairo diagram via the shared
-- 'denv'.
barDiag :: MCMCAnal ->
           [(Double, Double)] ->
           [(Double, Double)] ->
           QDiagram Cairo R2 Any
barDiag pt bvs bvs' =
  let rendered = runBackend denv (render (barChart pt bvs bvs') (500, 500))
  in  fst rendered
-- | Fixed sample clustered bar chart ("Cash" / "Equity" over four index
-- positions). When @borders@ is True each bar gets a 1.0-width black outline.
-- Otherwise bars have no outline and the title gains " (no borders)".
wikiChart :: Bool -> Graphics.Rendering.Chart.Renderable ()
wikiChart borders = toRenderable layout
  where
    (btitle, bstyle)
      | borders   = ("", Just (solidLine 1.0 $ opaque black))
      | otherwise = (" (no borders)", Nothing)

    layout :: Layout PlotIndex Double
    layout = layout_title .~ "Sample Bars" ++ btitle
           $ layout_title_style . font_size .~ 10
           $ layout_x_axis . laxis_generate .~ autoIndexAxis alabels
           $ layout_y_axis . laxis_override .~ axisGridHide
           $ layout_left_axis_visibility . axis_show_ticks .~ False
           $ layout_plots .~ [plotBars bars]
           $ def

    -- Item styles cycle through the default colour sequence.
    bars = plot_bars_titles .~ ["Cash", "Equity"]
         $ plot_bars_values .~ addIndexes [[20, 45], [45, 30], [30, 20], [70, 25]]
         $ plot_bars_style .~ BarsClustered
         $ plot_bars_spacing .~ BarsFixGap 30 5
         $ plot_bars_item_styles .~ [ (solidFillStyle colour, bstyle) | colour <- cycle defaultColorSeq ]
         $ def

    alabels = ["Jun", "Jul", "Aug", "Sep", "Oct"]
-- | 'wikiChart' rendered at 500 x 500 into a Cairo diagram via the shared
-- 'denv'.
wikiDiag :: Bool -> QDiagram Cairo R2 Any
wikiDiag borders =
  case runBackend denv (render (wikiChart borders) (500, 500)) of
    (diagram, _) -> diagram