Skip to content

Commit

Permalink
Fix broken latentdiffmodel test and switch off progress reports in te…
Browse files Browse the repository at this point in the history
…st sampling (#153)

* Fix difflatentmodel test

* switch `sample` to silent progress
  • Loading branch information
SamuelBrand1 authored Mar 15, 2024
1 parent 4b6ca67 commit b27d598
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion EpiAware/test/test_autoregressive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ end
fixed_model = fix(model, (σ_AR = σ_AR, damp_AR = damp, ar_init = ar_init))

n_samples = 100
samples = sample(fixed_model, Prior(), n_samples) |>
samples = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
gen[1]
end
Expand Down
7 changes: 4 additions & 3 deletions EpiAware/test/test_difflatentmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
latent_model, (σ_RW = 0, rw_init = 0.0))

n_samples = 2000
samples = sample(fixed_model, Prior(), n_samples) |>
samples = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(hcat, generated_quantities(fixed_model, chn)) do gen
gen[1]
end
Expand All @@ -56,7 +56,8 @@ end
end

#Plus day two distribution
day2_dist = foldl((x, y) -> _add_normals(x, init_priors[1]), 1:d, init = init_priors[2])
day2_dist = _add_normals(
Normal(d * init_priors[1].μ, d * init_priors[1].σ), init_priors[2])

ks_test_pval_day1 = ExactOneSampleKSTest(samples[1, :], init_priors[1]) |> pvalue
ks_test_pval_day2 = ExactOneSampleKSTest(samples[2, :], day2_dist) |> pvalue
Expand All @@ -81,7 +82,7 @@ end
(latent_init = [0.0, 1.0], σ_AR = 1.0, damp_AR = [0.8], ar_init = [0.0]))

n_samples = 100
samples = sample(fixed_model, Prior(), n_samples) |>
samples = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(hcat, generated_quantities(fixed_model, chn)) do gen
gen[1]
end
Expand Down
7 changes: 4 additions & 3 deletions EpiAware/test/test_epi-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ end
rt = [log(recent_incidence[1]) - log_init; diff(log.(recent_incidence))]

#Check log_init is sampled from the correct distribution
sample_init_inc = sample(generate_latent_infs(rt_model, rt), Prior(), 1000) |>
sample_init_inc = sample(
generate_latent_infs(rt_model, rt), Prior(), 1000; progress = false) |>
chn -> chn[:init_incidence] |>
Array |>
vec
Expand Down Expand Up @@ -100,7 +101,7 @@ end
#Check log_init is sampled from the correct distribution
sample_init_inc = sample(
generate_latent_infs(direct_inf_model, log_incidence),
Prior(), 1000) |>
Prior(), 1000; progress = false) |>
chn -> chn[:init_incidence] |>
Array |>
vec
Expand Down Expand Up @@ -143,7 +144,7 @@ end

#Check log_init is sampled from the correct distribution
@time sample_init_inc = sample(generate_latent_infs(renewal_model, log_Rt),
Prior(), 1000) |>
Prior(), 1000; progress = false) |>
chn -> chn[:init_incidence] |>
Array |>
vec
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end
latent_model = rwp,
observation_model = obs_model)

chn = sample(test_mdl, Prior(), 1000)
chn = sample(test_mdl, Prior(), 1000; progress = false)
gens = generated_quantities(test_mdl, chn)

#Check model sampled
Expand Down Expand Up @@ -108,7 +108,7 @@ end
observation_model = obs_model
)

chn = sample(test_mdl, Prior(), 1000)
chn = sample(test_mdl, Prior(), 1000; progress = false)
gens = generated_quantities(test_mdl, chn)

#Check model sampled
Expand Down
4 changes: 2 additions & 2 deletions EpiAware/test/test_observation-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
fix_mdl = fix(mdl, (neg_bin_cluster_factor = neg_bin_cf,))

n_samples = 1000
first_obs = sample(fix_mdl, Prior(), n_samples) |>
first_obs = sample(fix_mdl, Prior(), n_samples; progress = false) |>
chn -> generated_quantities(fix_mdl, chn) .|>
(gen -> gen[1][1]) |>
vec
Expand Down Expand Up @@ -64,7 +64,7 @@ end
@testset "$scenario_name y_t" begin
mdl = generate_observations(
delay_obs, y_t_scenario, I_t)
sampled_obs = sample(mdl, Prior(), 1000) |>
sampled_obs = sample(mdl, Prior(), 1000; progress = false) |>
chn -> generated_quantities(mdl, chn) .|>
(gen -> gen[1]) |>
collect
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_randomwalk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model = generate_latent(rw_process, n)
fixed_model = fix(model, (σ_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 = sample(fixed_model, Prior(), n_samples) |>
samples_day_5 = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
gen[1][5] #Extracting day 5 samples
end
Expand Down

0 comments on commit b27d598

Please sign in to comment.