Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EpiModel inner constructor for ingesting distributions directly. #28

Merged
merged 10 commits into from
Feb 12, 2024
2 changes: 1 addition & 1 deletion EpiAware/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "b4d488971893c2da3a7e5ee0a7a6da358a2c3ba6"
project_hash = "852af0e0beaa4accce6cd930983d2709e4f451f1"

[[deps.ADTypes]]
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
Expand Down
1 change: 1 addition & 0 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ using Distributions,
export scan,
create_discrete_pmf,
growth_rate_to_reproductive_ratio,
generate_observation_kernel,
EpiModel,
log_daily_infections,
random_walk
Expand Down
38 changes: 30 additions & 8 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,42 @@ struct EpiModel{T<:Real} <: AbstractEpiModel
len_gen_int::Integer #length(gen_int) just to save recalc
len_delay_int::Integer #length(delay_int) just to save recalc

#Inner constructor for EpiModel object
#Inner constructors for EpiModel object
function EpiModel(gen_int, delay_int, cluster_coeff, time_horizon)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(gen_int) ≈ 1 "Generation interval must sum to 1"
@assert sum(delay_int) ≈ 1 "Delay interval must sum to 1"
#construct observation delay kernel
K = zeros(time_horizon, time_horizon) |> SparseMatrixCSC
for i = 1:time_horizon, j = 1:time_horizon
m = (i - 1) - (j - 1)
if m >= 1 && m <= length(delay_int)
K[i, j] = delay_int[m]
end
end
K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
gen_int,
delay_int,
K,
cluster_coeff,
length(gen_int),
length(delay_int),
)
end

function EpiModel(
gen_distribution::ContinuousDistribution,
delay_distribution::ContinuousDistribution,
cluster_coeff,
time_horizon;
Δd = 1.0,
D_gen,
D_delay,
)
gen_int =
create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay)

#construct observation delay kernel
#Recall first element is zero delay
K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
gen_int,
Expand Down
96 changes: 13 additions & 83 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,98 +1,28 @@
@model function log_infections(
y_t,
::Type{T} = Float64;
epimodel::EpiModel,
latent_process,
latent_process;
latent_process_priors,
transform_function = exp,
n_generate_ahead = 0,
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
pos_shift = 1e-6,
α = missing,
) where {T}

I_t = Vector{T}(undef, gen_length)
mean_case_preds = Vector{T}(undef, gen_length)
data_length = length(y_t)

α ~ Gamma(3, 0.05 / 3)
neg_bin_cluster_factor = missing,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),
seabbs marked this conversation as resolved.
Show resolved Hide resolved
)
#Prior
neg_bin_cluster_factor ~ neg_bin_cluster_factor_prior
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved

#Latent process
@submodel _I_t, latent_process_parameters = latent_process()
time_steps = length(y_t) + n_generate_ahead
@submodel _I_t, latent_process_parameters =
latent_process(data_length; latent_process_priors = latent_process_priors)

#Transform into infections
I_t = transform_function.(_I_t)

#Predictive distribution
mean_case_preds .= epimodel.delay_kernel * I_t
case_pred_dists = mean_case_preds .+ pos_shift .|> μ -> mean_cc_neg_bin(μ, α)

#Likelihood
y_t ~ arraydist(case_pred_dists)

#Generate quantities
return (; I_t, latent_process_parameters)
end

@model function exp_growth_rate(
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
y_t,
::Type{T} = Float64;
epimodel::EpiModel,
latent_process,
transform_function = exp,
pos_shift = 1e-6,
α = missing,
_I_0 = missing,
) where {T}

I_t = Vector{T}(undef, gen_length)
mean_case_preds = Vector{T}(undef, gen_length)
data_length = length(y_t)

α ~ Gamma(3, 0.05 / 3)
_I_0 ~ Normal(0.0, 1.0)

#Latent process
@submodel rt, latent_process_parameters = latent_process()

#Transform into infections
I_t = transform_function.(_I_0 .+ cumsum(rt))

#Predictive distribution
mean_case_preds .= epimodel.delay_kernel * I_t
case_pred_dists = mean_case_preds .+ pos_shift .|> μ -> mean_cc_neg_bin(μ, α)

#Likelihood
y_t ~ arraydist(case_pred_dists)

#Generate quantities
return (; I_t, latent_process_parameters)
end

@model function renewal(
y_t,
::Type{T} = Float64;
epimodel::EpiModel,
latent_process,
transform_function = exp,
pos_shift = 1e-6,
α = missing,
_I_0 = missing,
) where {T}

I_t = Vector{T}(undef, gen_length)
mean_case_preds = Vector{T}(undef, gen_length)
data_length = length(y_t)

α ~ Gamma(3, 0.05 / 3)
_I_0 ~ MvNormal(ones(epimodel.len_gen_int)) #<-- need longer initial for renewal

#Latent process
@submodel Rt, latent_process_parameters = latent_process()

#Transform into infections
I_t, _ = scan(epimodel, transform_function.(_I_0), Rt)

#Predictive distribution
mean_case_preds .= epimodel.delay_kernel * I_t
case_pred_dists = mean_case_preds .+ pos_shift .|> μ -> mean_cc_neg_bin(μ, α)
case_pred_dists =
(epimodel.delay_kernel * I_t) .+ pos_shift .|> μ -> mean_cc_neg_bin(μ, α)

#Likelihood
y_t ~ arraydist(case_pred_dists)
Expand Down
24 changes: 24 additions & 0 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,27 @@ function mean_cc_neg_bin(μ, α)
r = μ^2 / ex_σ²
return NegativeBinomial(r, p)
end


"""
generate_observation_kernel(delay_int, time_horizon)

Generate an observation kernel matrix based on the given delay interval and time horizon.

# Arguments
- `delay_int::Vector{Float64}`: The delay PMF vector.
- `time_horizon::Int`: The number of time steps of the observation period.

# Returns
- `K::SparseMatrixCSC{Float64, Int}`: The observation kernel matrix.
"""
function generate_observation_kernel(delay_int, time_horizon)
K = zeros(eltype(delay_int), time_horizon, time_horizon) |> SparseMatrixCSC
for i = 1:time_horizon, j = 1:time_horizon
m = i - j
if m >= 0 && m <= (length(delay_int) - 1)
K[i, j] = delay_int[m+1]
end
end
return K
end
2 changes: 1 addition & 1 deletion EpiAware/test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "0b01aa91e53bb772f02a49192dfa1019eaa23f4b"
project_hash = "0dea5a2fa6648a3a05ed8cb24ee73213ffe76d33"

[[deps.ADTypes]]
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
Expand Down
1 change: 1 addition & 0 deletions EpiAware/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Expand Down
63 changes: 63 additions & 0 deletions EpiAware/test/test_epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,66 @@ end

@test model(recent_incidence, Rt) == expected_output
end
@testitem "EpiModel constructor" begin
gen_int = [0.2, 0.3, 0.5]
delay_int = [0.1, 0.4, 0.5]
cluster_coeff = 0.8
time_horizon = 10

model = EpiModel(gen_int, delay_int, cluster_coeff, time_horizon)

@test length(model.gen_int) == 3
@test length(model.delay_int) == 3
@test model.cluster_coeff == 0.8
@test model.len_gen_int == 3
@test model.len_delay_int == 3

@test sum(model.gen_int) ≈ 1
@test sum(model.delay_int) ≈ 1

@test size(model.delay_kernel) == (time_horizon, time_horizon)
end

@testitem "EpiModel function" begin
using LinearAlgebra
recent_incidence = [10, 20, 30]
Rt = 1.5

expected_new_incidence = Rt * dot(recent_incidence, [0.2, 0.3, 0.5])
expected_output =
[expected_new_incidence; recent_incidence[1:2]], expected_new_incidence

model = EpiModel([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10)

@test model(recent_incidence, Rt) == expected_output
end

@testitem "EpiModel constructor with distributions" begin
using Distributions

gen_distribution = Uniform(0.0, 10.0)
delay_distribution = Exponential(1.0)
cluster_coeff = 0.8
time_horizon = 10
D_gen = 10.0
D_delay = 10.0
Δd = 1.0

model = EpiModel(
gen_distribution,
delay_distribution,
cluster_coeff,
time_horizon;
D_gen = 10.0,
D_delay = 10.0,
)

@test model.cluster_coeff == 0.8
@test model.len_gen_int == Int64(D_gen / Δd) - 1
@test model.len_delay_int == Int64(D_delay / Δd)

@test sum(model.gen_int) ≈ 1
@test sum(model.delay_int) ≈ 1

@test size(model.delay_kernel) == (time_horizon, time_horizon)
end
22 changes: 21 additions & 1 deletion EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ end

end

@testitem"Testing growth_rate_to_reproductive_ratio function" begin
@testitem "Testing growth_rate_to_reproductive_ratio function" begin
#Test that zero exp growth rate imples R0 = 1
@testset "Test case 1" begin
r = 0
Expand All @@ -99,3 +99,23 @@ end
end

end

@testitem "Testing generate_observation_kernel function" begin
using SparseArrays
@testset "Test case 1" begin
delay_int = [0.2, 0.5, 0.3]
time_horizon = 5
expected_K = SparseMatrixCSC(
[
0.2 0 0 0 0
0.5 0.2 0 0 0
0.3 0.5 0.2 0 0
0 0.3 0.5 0.2 0
0 0 0.3 0.5 0.2
],
)
K = generate_observation_kernel(delay_int, time_horizon)
@test K == expected_K
end

end
Loading