Skip to content

Commit

Permalink
Issue 357: Move from fixed n with missing to flexible n (#392)
Browse files Browse the repository at this point in the history
* initial change

* fix tests

* Add a test for new truncated Y_t support

* expand unit tests of latent delay change

* add doc string

* add integration tests for EpiAwareObs after local tests
  • Loading branch information
seabbs authored Jul 23, 2024
1 parent d4d46df commit db599a2
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 22 deletions.
15 changes: 9 additions & 6 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
@doc raw"
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues.
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and for expected observations (`Y_t`) that are shorter than the observations. When this is the case, it assumes that the
expected observations correspond to the last `length(Y_t)` elements of `y_t`.
It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues.
It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required.
"
Expand All @@ -10,15 +12,16 @@ It dispatches to the `observation_error` function to generate the observation er
@submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t)

if ismissing(y_t)
y_t = Vector{Union{Real, Missing}}(missing, length(Y_t))
else
@assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length."
y_t = Vector{Missing}(missing, length(Y_t))
end

diff_t = length(y_t) - length(Y_t)
@assert diff_t>=0 "The observation vector must be longer than or equal to the expected observation vector"

pad_Y_t = Y_t .+ 1e-6

for i in findfirst(!ismissing, Y_t):length(Y_t)
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
for i in eachindex(Y_t)
y_t[i + diff_t] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t
Expand Down
18 changes: 9 additions & 9 deletions EpiAware/src/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@doc raw"
The `LatentDelay` struct represents an observation model that introduces a latent delay in the observations. It is a subtype of `AbstractTuringObservationModel`.
Note that the `LatentDelay` observation model shortens the observation vector by the length of the delay distribution and this is then passed to the underlying observation model. This is to prevent fitting to partially
Note that the `LatentDelay` observation model shortens the expected observation vector by one less than the length of the delay distribution, and the shortened vector is then passed to the underlying observation model. This is to prevent fitting to partially
observed data.
## Fields
Expand Down Expand Up @@ -61,21 +61,21 @@ Generates observations based on the `LatentDelay` observation model.
"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
first_Y_t = findfirst(!ismissing, Y_t)
trunc_Y_t = Y_t[first_Y_t:end]
if ismissing(y_t)
y_t = Vector{Missing}(missing, length(Y_t))
end

pmf_length = length(obs_model.rev_pmf)
@assert pmf_length<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector"
@assert pmf_length<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector"

expected_obs = accumulate_scan(
LDStep(obs_model.rev_pmf),
(; val = 0, current = trunc_Y_t[1:(pmf_length)]),
vcat(trunc_Y_t[(pmf_length + 1):end], 0.0)
(; val = 0, current = Y_t[1:(pmf_length)]),
vcat(Y_t[(pmf_length + 1):end], 0.0)
)

complete_obs = vcat(fill(missing, pmf_length + first_Y_t - 2), expected_obs)

@submodel y_t = generate_observations(
obs_model.model, y_t, complete_obs)
obs_model.model, y_t, expected_obs)

return y_t
end
Expand Down
11 changes: 10 additions & 1 deletion EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,14 @@
@test isapprox(draw[1], 20, atol = 1e-3)
end

@test_throws AssertionError generate_observations(obs_model, vcat(1, I_t), I_t)()
@testset "Test works with truncated expected observations" begin
mdl = generate_observations(obs_model, fill(missing, 5), I_t[(end - 3):end])
draw = mdl()
@test all(map(zip(draw[(end - 3):end], I_t[(end - 3):end])) do (draw, I_t)
isapprox(draw, I_t, atol = 1e-3)
end)
@test ismissing(draw[1])
end

@test_throws AssertionError generate_observations(obs_model, I_t, vcat(1, I_t))()
end
23 changes: 17 additions & 6 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,31 @@ end

@testset "Test with entirely missing data" begin
mdl = generate_observations(obs_model, missing, I_t)
@test mdl()[3:end] == expected_obs[3:end]
@test sum(mdl() .|> ismissing) == 2
@test mdl() == expected_obs[3:end]
@test sum(mdl() .|> ismissing) == 0
end

@testset "Test with missing data defined as a vector" begin
mdl = generate_observations(
obs_model, [missing, missing, missing, missing, missing], I_t)
@test mdl()[3:end] == expected_obs[3:end]
@test sum(mdl() .|> ismissing) == 2
@test mdl() == expected_obs[3:end]
@test sum(mdl() .|> ismissing) == 0
end

@testset "Test with data" begin
pois_obs_model = LatentDelay(PoissonError(), delay_int)
@testset "Test with a real observation error model" begin
using Turing, DynamicPPL
pois_obs_model = LatentDelay(RecordExpectedObs(PoissonError()), delay_int)
missing_mdl = generate_observations(pois_obs_model, missing, I_t)
missing_draws = missing_mdl()
@test all(ismissing.(missing_draws[1:2]))
@test !any(ismissing.(missing_draws[3:end]))

mdl = generate_observations(pois_obs_model, [10.0, 20.0, 30.0, 40.0, 50.0], I_t)
@test mdl() == [10.0, 20.0, 30.0, 40.0, 50]
samples = sample(mdl, Prior(), 10; progress = false)
exp_y_t = get(samples, :exp_y_t).exp_y_t
@test exp_y_t[1][1] == expected_obs[3]
@test exp_y_t[2][1] == expected_obs[4]
@test exp_y_t[3][1] == expected_obs[5]
end
end
52 changes: 52 additions & 0 deletions benchmark/bench/EpiObsModels/integration.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Register an empty benchmark group under the "Integration" key of `suite`.
suite["Integration"] = BenchmarkGroup()

let
    using Distributions

    # Ascertainment around a negative binomial error model, built with a
    # Normal(0, 1) intercept and an elementwise `exp` link.
    ascertained_model = Ascertainment(
        NegativeBinomialError(), Intercept(Normal(0, 1)); link = x -> exp.(x)
    )
    # Wrap the ascertainment model in a latent delay given as LogNormal(1.6, 0.2).
    delayed_model = LatentDelay(ascertained_model, LogNormal(1.6, 0.2))

    # The same constant series is passed as both `y_t` and `Y_t`.
    infections = fill(100, 10)
    suite["Integration"]["LatentDelay - Ascertainment"] = make_epiaware_suite(
        generate_observations(delayed_model, infections, infections)
    )
end

let
    # The same constant series is passed as both `y_t` and `Y_t`.
    series = fill(10, 100)

    # Two stacked delays: a four-element delay vector around the negative
    # binomial error model, then a LogNormal(1.4, 0.2) delay around that.
    inner_delay = LatentDelay(NegativeBinomialError(), [0.1, 0.2, 0.3, 0.4])
    nested_delay = LatentDelay(inner_delay, LogNormal(1.4, 0.2))

    suite["Integration"]["LatentDelay-LatentDelay"] = make_epiaware_suite(
        generate_observations(nested_delay, series, series)
    )
end

let
    # Component observation models, stacked under the names "cases" and "deaths".
    ascertained = Ascertainment(NegativeBinomialError(), Intercept(Normal(0.5, 0.1)))
    delayed = LatentDelay(PoissonError(), [0.1, 0.2, 0.3, 0.4])
    stacked = StackObservationModels([ascertained, delayed], ["cases", "deaths"])

    # The same constant series is used as `Y_t` and, keyed by name, as `y_t`.
    counts = fill(10, 10)
    observed = (cases = counts, deaths = counts)

    suite["Integration"]["StackObservationModels-LatentDelay-Ascertainment"] = make_epiaware_suite(
        generate_observations(stacked, observed, counts)
    )
end

0 comments on commit db599a2

Please sign in to comment.