diff --git a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl index 0d9da447a..b26399c72 100644 --- a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl +++ b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl @@ -129,21 +129,19 @@ We can sample from this model, which is useful for model diagnostic and prior pr " # ╔═╡ fbe117b7-a0b8-4604-a5dd-e71a0a1a4fc3 -plt_ar_sample = let - n_samples = 100 - ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(ar_mdl, θ) - end - - plot(ar_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from the AR(2) model", - ylabel = "Log Rt") +n_samples = 100 +ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ + θ = rand(ar_mdl) #Sample unconditionally the underlying parameters of the model + gen = generated_quantities(ar_mdl, θ) end +plot(ar_mdl_samples, + lab = "", + c = :grey, + alpha = 0.25, + title = "$(n_samples) draws from the AR(2) model", + ylabel = "Log Rt") + # ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e md" And we can sample from this model with some parameters conditioned, for example with $\sigma = 0$. In this case the AR process is an initial perturbation model with return to baseline. @@ -153,21 +151,19 @@ And we can sample from this model with some parameters conditioned, for example cond_ar_mdl = ar_mdl | (σ_AR = 0.0,) # ╔═╡ d3938381-01b7-40c6-b369-a456ff6dba72 -let - n_samples = 100 - ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(cond_ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(cond_ar_mdl, θ) - end - - plot(ar_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "AR(2) model conditioned on sigma = 0", - ylabel = "Log Rt") +n_samples = 100 +ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ + θ = rand(cond_ar_mdl) #Sample unconditionally the underlying parameters of the model + gen = generated_quantities(cond_ar_mdl, θ) end +plot(ar_mdl_samples, + lab = "", + c = :grey, + alpha = 0.25, + title = "AR(2) model conditioned on sigma = 0", + ylabel = "Log Rt") + # ╔═╡ 12fd3bd5-657e-4b1a-aa88-6063419aaceb md" In this note, we are going to treat $R_t$ as varying every two days. The reason for this is to 1) reduce the effective number of parameters, and 2) showcase the `BroadcastLatentModel` wrapper. @@ -179,22 +175,20 @@ In `EpiAware` we set this behaviour by wrapping a `LatentModel` in a `BroadcastL twod_ar = BroadcastLatentModel(ar, 2, RepeatBlock()) # ╔═╡ 5a96e7e9-0376-4365-8eb1-b2fad9be8fef -let - n_samples = 100 - twod_ar_mdl = generate_latent(twod_ar, 30) - twod_ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(twod_ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(twod_ar_mdl, θ)[1] - end - - plot(twod_ar_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from the weekly AR(2) model", - ylabel = "Log Rt") +n_samples = 100 +twod_ar_mdl = generate_latent(twod_ar, 30) +twod_ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ + θ = rand(twod_ar_mdl) #Sample unconditionally the underlying parameters of the model + gen = generated_quantities(twod_ar_mdl, θ)[1] end +plot(twod_ar_mdl_samples, + lab = "", + c = :grey, + alpha = 0.25, + title = "$(n_samples) draws from the weekly AR(2) model", + ylabel = "Log Rt") + # ╔═╡ 6a9e871f-a2fa-4e41-af89-8b0b3c3b5b4b md" ## The Renewal model as an `EpiModel` type @@ -211,16 +205,14 @@ truth_GI = Gamma(6.5, 0.62) model_data = EpiData(gen_distribution = truth_GI) # ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683 -let - bar(model_data.gen_int, - fillalpha = 0.5, - lw = 0, - lab = "Discretized next gen pmf", - xticks = 0:14, - xlabel = "Days", - title = "Continuous and discrete generation intervals") - plot!(truth_GI, lab = "Continuous serial interval") -end +bar(model_data.gen_int, + fillalpha = 0.5, + lw = 0, + lab = "Discretized next gen pmf", + xticks = 0:14, + xlabel = "Days", + title = "Continuous and discrete generation intervals") +plot!(truth_GI, lab = "Continuous serial interval") # ╔═╡ 4a2b5cf1-623c-4fe7-8365-49fb7972af5a md" @@ -263,28 +255,8 @@ R_t_fixed = [0.5 + 2.5 / (1 + exp(t - 15)) for t in 1:30] latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed)) # ╔═╡ 7a6d4b14-58d3-40c1-81f2-713c830f875f -plt_epi = let - n_samples = 100 - epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(latent_inf_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(latent_inf_mdl, θ) - end - - p1 = plot(epi_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from renewal model with chosen Rt", - ylabel = "Latent infections" - ) - p2 = plot(R_t_fixed, - lab = "", - lw = 2, - ylabel = "Rt" - ) - - plot(p1, p2, layout = (2, 1)) -end +n_samples = 100 +e # ╔═╡ c8ef8a60-d087-4ae9-ae92-abeea5afc7ae md" @@ -323,24 +295,22 @@ expected_cases = [1000 * exp(-(t - 15)^2 / (2 * 4)) for t in 1:30] obs_mdl = generate_observations(obs, missing, expected_cases) # ╔═╡ c3a62dda-e054-4c8c-b1b8-ba1b5c4447b3 -plt_obs = let - n_samples = 100 - obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(obs_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(obs_mdl, θ)[1] - end - scatter(obs_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from neg. bin. obs model", - ylabel = "Observed cases" - ) - plot!(expected_cases, - c = :red, - lw = 3, - lab = "Expected cases") +n_samples = 100 +obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _ + θ = rand(obs_mdl) #Sample unconditionally the underlying parameters of the model + gen = generated_quantities(obs_mdl, θ)[1] end +scatter(obs_mdl_samples, + lab = "", + c = :grey, + alpha = 0.25, + title = "$(n_samples) draws from neg. bin. obs model", + ylabel = "Observed cases" +) +plot!(expected_cases, + c = :red, + lw = 3, + lab = "Expected cases") # ╔═╡ de5d96f0-4df6-4cc3-9f1d-156176b2b676 md"A _reverse_ observation model, which samples the underlying latent infections conditional on observations would require a prior on the latent infections. This is the purpose of composing multiple models; as we'll see below the latent infection and latent $R_t$ models are informative priors on the latent infection time series underlying the observations." @@ -443,56 +413,54 @@ In a future note, we'll demonstrate having a time-varying ascertainment rate. " # ╔═╡ 8b557bf1-f3dd-4f42-a250-ce965412eb32 -let - C = south_korea_data.y_t - D = south_korea_data.dates - gens = inference_results.generated - - #Unconditional model for posterior predictive sampling - mdl_unconditional = generate_epiaware(epi_prob, (y_t = missing,)) - predicted_y_t = mapreduce( - hcat, generated_quantities(mdl_unconditional, inference_results.samples)) do gen - gen.generated_y_t - end - predicted_I_t = mapreduce( - hcat, gens) do gen - gen.I_t - end - predicted_R_t = mapreduce( - hcat, gens) do gen - exp.(gen.Z_t) - end - - p1 = plot(D, predicted_y_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p1, D, C, - lab = "Actual cases", - ylabel = "Daily Cases", - title = "Post. predictive: Cases", - ylims = (-0.5, maximum(C) * 2), - c = :red - ) - - p2 = plot(D, predicted_I_t, - c = :grey, - alpha = 0.05, - lab = "", - ylabel = "Daily latent infections", - ylims = (-0.5, maximum(C) * 1.5), - title = "Prediction: Latent infections" - ) - - p3 = plot(D, predicted_R_t, - c = :grey, - alpha = 0.025, - lab = "", - ylabel = "Rt", - title = "Prediction: Reproduction number", - yscale = :log10 - ) - hline!(p3, [1.0], lab = "Rt = 1", lw = 2, c = :blue) - - plot(p1, p2, p3, layout = (3, 1), size = (500, 700), left_margin = 5mm) +C = south_korea_data.y_t +D = south_korea_data.dates +gens = inference_results.generated + +#Unconditional model for posterior predictive sampling +mdl_unconditional = generate_epiaware(epi_prob, (y_t = missing,)) +predicted_y_t = mapreduce( + hcat, generated_quantities(mdl_unconditional, inference_results.samples)) do gen + gen.generated_y_t +end +predicted_I_t = mapreduce( + hcat, gens) do gen + gen.I_t end +predicted_R_t = mapreduce( + hcat, gens) do gen + exp.(gen.Z_t) +end + +p1 = plot(D, predicted_y_t, c = :grey, alpha = 0.05, lab = "") +scatter!(p1, D, C, + lab = "Actual cases", + ylabel = "Daily Cases", + title = "Post. predictive: Cases", + ylims = (-0.5, maximum(C) * 2), + c = :red +) + +p2 = plot(D, predicted_I_t, + c = :grey, + alpha = 0.05, + lab = "", + ylabel = "Daily latent infections", + ylims = (-0.5, maximum(C) * 1.5), + title = "Prediction: Latent infections" +) + +p3 = plot(D, predicted_R_t, + c = :grey, + alpha = 0.025, + lab = "", + ylabel = "Rt", + title = "Prediction: Reproduction number", + yscale = :log10 +) +hline!(p3, [1.0], lab = "Rt = 1", lw = 2, c = :blue) + +plot(p1, p2, p3, layout = (3, 1), size = (500, 700), left_margin = 5mm) # ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f md" @@ -502,53 +470,51 @@ We can interrogate the sampled chains directly from the `samples` field of the ` " # ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296 -let - p1 = histogram(inference_results.samples["obs.cluster_factor"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: Neg. bin. cluster factor") - plot!(p1, obs.cluster_factor_prior, - lw = 3, - c = :black, - lab = "prior") - - p2 = histogram(inference_results.samples[:init_incidence], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: log-initial incidence") - plot!(p2, epi.initialisation_prior, - lw = 3, - c = :black, - lab = "prior") - - p3 = histogram(inference_results.samples["latent.damp_AR[1]"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: rho_1") - plot!(p3, ar.damp_prior.v[1], - lw = 3, - c = :black, - lab = "prior") - - p4 = histogram(inference_results.samples["latent.damp_AR[2]"], - lab = "chain " .* string.([1 2 3 4]), - fillalpha = 0.4, - lw = 0, - norm = :pdf, - title = "Posterior dist: rho_2") - plot!(p4, ar.damp_prior.v[2], - lw = 3, - c = :black, - lab = "prior") - - plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600)) -end +p1 = histogram(inference_results.samples["obs.cluster_factor"], + lab = "chain " .* string.([1 2 3 4]), + fillalpha = 0.4, + lw = 0, + norm = :pdf, + title = "Posterior dist: Neg. bin. cluster factor") +plot!(p1, obs.cluster_factor_prior, + lw = 3, + c = :black, + lab = "prior") + +p2 = histogram(inference_results.samples[:init_incidence], + lab = "chain " .* string.([1 2 3 4]), + fillalpha = 0.4, + lw = 0, + norm = :pdf, + title = "Posterior dist: log-initial incidence") +plot!(p2, epi.initialisation_prior, + lw = 3, + c = :black, + lab = "prior") + +p3 = histogram(inference_results.samples["latent.damp_AR[1]"], + lab = "chain " .* string.([1 2 3 4]), + fillalpha = 0.4, + lw = 0, + norm = :pdf, + title = "Posterior dist: rho_1") +plot!(p3, ar.damp_prior.v[1], + lw = 3, + c = :black, + lab = "prior") + +p4 = histogram(inference_results.samples["latent.damp_AR[2]"], + lab = "chain " .* string.([1 2 3 4]), + fillalpha = 0.4, + lw = 0, + norm = :pdf, + title = "Posterior dist: rho_2") +plot!(p4, ar.damp_prior.v[2], + lw = 3, + c = :black, + lab = "prior") + +plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600)) # ╔═╡ Cell order: # ╟─a59d977c-0178-11ef-0063-83e30e0cf9f0