Skip to content

Commit

Permalink
commit all
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Mar 12, 2024
1 parent 3b2d479 commit 085c4b9
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 50 deletions.
5 changes: 3 additions & 2 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ Z_0 &\sim \mathcal{N}(0,1),\\
"

# ╔═╡ 56ae496b-0094-460b-89cb-526627991717
rwp = EpiAware.RandomWalk(Normal(),
EpiAware._make_halfnormal_prior(0.1))
rwp = EpiAware.RandomWalk(
init_prior = Normal(),
std_prior = EpiAware._make_halfnormal_prior(0.1))

# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44
md"
Expand Down
3 changes: 1 addition & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ export make_epi_aware
export generate_latent, generate_latent_infs, generate_observations

# Exported utilities
export create_discrete_pmf, spread_draws, scan, R_to_r, r_to_R,
default_rw_priors, default_delay_obs_priors
export create_discrete_pmf, spread_draws, scan, R_to_r, r_to_R, default_delay_obs_priors

# Exported inference methods
export manypathfinder
Expand Down
33 changes: 26 additions & 7 deletions EpiAware/src/abstract-types.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
abstract type AbstractModel end

"""
abstract type AbstractEpiModel <: AbstractModel end
The abstract supertype for all structs that define a model for generating unobserved/latent
infections.
infections.
"""
abstract type AbstractEpiModel <: AbstractModel end

"""
The abstract supertype for all structs that define a model for generating a latent process
used in `EpiAware` models.
"""
abstract type AbstractLatentModel <: AbstractModel end

abstract type AbstractObservationModel <: AbstractModel end

@doc raw"""
Generate unobserved/latent infections based on the given `epi_model <: AbstractEpimodel`
and a latent process path ``Z_t``.
Constructor function for unobserved/latent infections based on the type of
`epi_model <: AbstractEpimodel` and a latent process path ``Z_t``.
The `generate_latent_infs` function implements a model of generating unobserved/latent
infections conditional on a latent process. Which model of generating unobserved/latent
Expand All @@ -25,15 +27,32 @@ defined for the given `epi_model`, then `EpiAware` will return a warning and ret
## Interface to `Turing.jl` probablilistic programming language (PPL)
Apart from the no implementation fallback method, the `generate_latent_infs` implementation
function should be a constructor function for a
function returns a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object.
object where the unobserved/latent infections are a generated quantity. Priors for model
parameters are fields of `epi_model`.
"""
function generate_latent_infs(epi_model::AbstractEpiModel, Z_t)
@warn "No concrete implementation for `generate_latent_infs` is defined."
return nothing
end

@doc raw"""
Constructor function for a latent process path ``Z_t`` of length `n`.
The `generate_latent` function implements a model of generating a latent process. Which
model for generating the latent process infections is implemented is set by the type of
`latent_model`. If no implemention is defined for the type of `latent_model`, then
`EpiAware` will pass a warning and return `nothing`.
## Interface to `Turing.jl` probablilistic programming language (PPL)
Apart from the no implementation fallback method, the `generate_latent` implementation
function should return a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object. Sample paths of ``Z_t`` are generated quantities of the constructed model. Priors
for model parameters are fields of `epi_model`.
"""
function generate_latent(latent_model::AbstractLatentModel, n)
@info "No concrete implementation for generate_latent is defined."
return nothing
Expand Down
13 changes: 7 additions & 6 deletions EpiAware/src/epimodels/directinfections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ I_t = g(\hat{I}_0 + Z_t).
```
where ``g`` is a transformation function and the unconstrained initial infections
``\hat{I}_0`` are sampled from a prior distribution, `initialisation_prior` which must
be supplied to the `DirectInfections` constructor. The default `initialisation_prior` is
the standard Normal `Distributions.Normal()`.
``\hat{I}_0`` are sampled from a prior distribution.
## Constructor
`DirectInfections` are constructed by passing an `EpiData` object `data` and an
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.
`DirectInfections` can be constructed by passing an `EpiData` object and subtype of
[`Distributions.Sampleable`](https://juliastats.org/Distributions.jl/latest/types/#Sampleable).
## Constructors
- `DirectInfections(; data, initialisation_prior)`
## Example usage with `generate_latent_infs`
Expand Down
11 changes: 6 additions & 5 deletions EpiAware/src/epimodels/expgrowthrate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ I_t = g(\hat{I}_0) \exp(Z_t).
```
where ``g`` is a transformation function and the unconstrained initial infections
``\hat{I}_0`` are sampled from a prior distribution, `initialisation_prior` which must
be supplied to the `DirectInfections` constructor. The default `initialisation_prior` is
the standard Normal `Distributions.Normal()`.
``\hat{I}_0`` are sampled from a prior distribution.
`ExpGrowthRate` are constructed by passing an `EpiData` object `data` and an
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.
## Constructor
`ExpGrowthRate` can be constructed by passing an `EpiData` object and and subtype of
[`Distributions.Sampleable`](https://juliastats.org/Distributions.jl/latest/types/#Sampleable).
- `ExpGrowthRate(; data, initialisation_prior)`.
## Example usage with `generate_latent_infs`
Expand Down
15 changes: 9 additions & 6 deletions EpiAware/src/epimodels/renewal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ I_t &= g(\hat{I}_0) \exp(r(\mathcal{R}_1) t), \qquad t \leq 0.
```
where ``g`` is a transformation function and the unconstrained initial infections
``\hat{I}_0`` are sampled from a prior distribution, `initialisation_prior` which must
be supplied to the `DirectInfections` constructor. The default `initialisation_prior` is
the standard Normal `Distributions.Normal()`. The discrete generation interval is given by
``g_i``.
``\hat{I}_0`` are sampled from a prior distribution. The discrete generation interval is
given by ``g_i``.
``r(\mathcal{R}_1)`` is the exponential growth rate implied by ``\mathcal{R}_1)``
using the implicit relationship between the exponential growth rate and the reproduction
Expand All @@ -29,10 +27,13 @@ number.
\mathcal{R} \sum_{j \geq 1} g_j \exp(- r j)= 1.
```
`Renewal` are constructed by passing an `EpiData` object `data` and an
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.
## Constructor
`Renewal` can be constructed by passing an `EpiData` object and and subtype of
[`Distributions.Sampleable`](https://juliastats.org/Distributions.jl/latest/types/#Sampleable).
- `Renewal(; data, initialisation_prior)`.
## Example usage with `generate_latent_infs`
Expand Down Expand Up @@ -82,6 +83,8 @@ I_t = generated_quantities(latent_inf, θ)
end

@doc """
function (epi_model::Renewal)(recent_incidence, Rt)
Callable on a `Renewal` struct for compute new incidence based on recent incidence and Rt.
## Mathematical specification
Expand Down
94 changes: 87 additions & 7 deletions EpiAware/src/latentmodels/randomwalk.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,93 @@
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D
std_prior::S
end
@doc raw"
Model latent process ``Z_t`` as a random walk.
## Mathematical specification
The random walk ``Z_t`` is specified as a parameteric transformation of the white noise
sequence ``(\epsilon_t)_{t\geq 1}``,
```math
Z_t = Z_0 + \sigma \sum_{i = 1}^t \epsilon_t
```
Constructing a random walk requires specifying:
- An `init_prior` as a prior for ``Z_0``. Default is `Normal()`.
- A `std_prior` for ``\sigma``. The default is HalfNormal with a mean of 0.25.
## Constructors
- `RandomWalk(; init_prior, std_prior)`
## Example usage with `generate_latent`
`generate_latent` can be used to construct a `Turing` model for the random walk ``Z_t``.
First, we construct a `RandomWalk` struct with priors,
```julia
using Distributions, Turing, EpiAware
# Create a RandomWalk model
rw = RandomWalk(init_prior = Normal(2., 1.),
std_prior = _make_halfnormal_prior(0.1))
```
function default_rw_priors()
return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf),
:init_rw_value_prior => Normal()) |> Dict
Then, we can use `generate_latent` to construct a Turing model for a 10 step random walk.
```julia
# Construct a Turing model
rw_model = generate_latent(rw, 10)
```
Now we can use the `Turing` PPL API to sample underlying parameters and generate the
unobserved infections.
```julia
#Sample random parameters from prior
θ = rand(rw_model)
#Get random walk sample path as a generated quantities from the model
Z_t, _ = generated_quantities(rw_model, θ)
```
"
@kwdef struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
"Prior for the initial distribution of the random walk."
init_prior::D = Normal()
"Prior for the standard deviation of the random walk step size."
std_prior::S = _make_halfnormal_prior(0.25)
end

"""
Implement the `generate_latent` function for the `RandomWalk` model.
## Example usage of `generate_latent` with `RandomWalk` type of latent process model
```julia
using Distributions, Turing, EpiAware
# Create a RandomWalk model
rw = RandomWalk(init_prior = Normal(2., 1.),
std_prior = _make_halfnormal_prior(0.1))
```
Then, we can use `generate_latent` to construct a Turing model for a 10 step random walk.
```julia
# Construct a Turing model
rw_model = generate_latent(rw, 10)
```
Now we can use the `Turing` PPL API to sample underlying parameters and generate the
unobserved infections.
```julia
#Sample random parameters from prior
θ = rand(rw_model)
#Get random walk sample path as a generated quantities from the model
Z_t, _ = generated_quantities(rw_model, θ)
```
"""
@model function generate_latent(latent_model::RandomWalk, n)
ϵ_t ~ MvNormal(ones(n))
σ_RW ~ latent_model.std_prior
Expand Down
18 changes: 3 additions & 15 deletions EpiAware/test/test_latent-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
using HypothesisTests: ExactOneSampleKSTest, pvalue

n = 5
priors = EpiAware.default_rw_priors()
rw_process = EpiAware.RandomWalk(Normal(0.0, 1.0),
truncated(Normal(0.0, 0.05), 0.0, Inf))
rw_process = EpiAware.RandomWalk(
init_prior = Normal(0.0, 1.0),
std_prior = truncated(Normal(0.0, 0.05), 0.0, Inf))
model = EpiAware.generate_latent(rw_process, n)
fixed_model = fix(model, (σ_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
Expand All @@ -18,19 +18,7 @@
ks_test_pval = ExactOneSampleKSTest(samples_day_5, Normal(0.0, sqrt(5))) |> pvalue
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
end
@testitem "Testing default_rw_priors" begin
@testset "var_RW_prior" begin
priors = EpiAware.default_rw_priors()
var_RW = rand(priors[:var_RW_prior])
@test var_RW >= 0.0
end

@testset "init_rw_value_prior" begin
priors = EpiAware.default_rw_priors()
init_rw_value = rand(priors[:init_rw_value_prior])
@test typeof(init_rw_value) == Float64
end
end
@testset "Testing RandomWalk constructor" begin
init_prior = Normal(0.0, 1.0)
std_prior = truncated(Normal(0.0, 0.05), 0.0, Inf)
Expand Down

0 comments on commit 085c4b9

Please sign in to comment.