diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 1fb193a21..cc2047c13 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -55,9 +55,10 @@ include("EpiInference/EpiInference.jl") @reexport using .EpiInference # Non-submodule exports -export make_epi_aware +export make_epi_aware, EpiAwareProblem include("docstrings.jl") +include("epiawareproblems/epiawareprob.jl") include("make_epi_aware.jl") end diff --git a/EpiAware/src/EpiAwareBase/EpiAwareBase.jl b/EpiAware/src/EpiAwareBase/EpiAwareBase.jl index 2d803a58f..ee6d5fb72 100644 --- a/EpiAware/src/EpiAwareBase/EpiAwareBase.jl +++ b/EpiAware/src/EpiAwareBase/EpiAwareBase.jl @@ -7,7 +7,7 @@ Module for defining abstract epidemiological types. using DocStringExtensions export AbstractModel, AbstractEpiModel, AbstractLatentModel, - AbstractObservationModel, generate_latent, + AbstractObservationModel, AbstractEpiAwareProblem, generate_latent, generate_latent_infs, generate_observations include("docstrings.jl") diff --git a/EpiAware/src/EpiAwareBase/types.jl b/EpiAware/src/EpiAwareBase/types.jl index f69af6d9a..c65258fac 100644 --- a/EpiAware/src/EpiAwareBase/types.jl +++ b/EpiAware/src/EpiAwareBase/types.jl @@ -11,3 +11,8 @@ abstract type AbstractEpiModel <: AbstractModel end abstract type AbstractLatentModel <: AbstractModel end abstract type AbstractObservationModel <: AbstractModel end + +""" +Abstract supertype for all `EpiAware` problems. +""" +abstract type AbstractEpiAwareProblem end diff --git a/EpiAware/src/epiawareproblems/epiawareprob.jl b/EpiAware/src/epiawareproblems/epiawareprob.jl new file mode 100644 index 000000000..1e1abdcfc --- /dev/null +++ b/EpiAware/src/epiawareproblems/epiawareprob.jl @@ -0,0 +1,22 @@ +""" +Defines an inference/generative modelling problem for case data. + +`EpiAwareProblem` wraps the underlying components of an epidemiological model: +- `epi_model`: An epidemiological model for unobserved infections. +- `latent_model`: A latent model for underlying latent process. +- `observation_model`: An observation model for observed cases. + +Along with a `tspan` tuple for the time span of the case data. +""" +@kwdef struct EpiAwareProblem{ + E <: AbstractEpiModel, L <: AbstractLatentModel, O <: AbstractObservationModel} <: + AbstractEpiAwareProblem + "Epidemiological model for unobserved infections." + epi_model::E + "Latent model for underlying latent process." + latent_model::L + "Observation model for observed cases." + observation_model::O + "Time span for either inference or generative modelling of case time series." + tspan::Tuple{Int, Int} +end diff --git a/EpiAware/src/make_epi_aware.jl b/EpiAware/src/make_epi_aware.jl index 2e1959405..df50c0cbe 100644 --- a/EpiAware/src/make_epi_aware.jl +++ b/EpiAware/src/make_epi_aware.jl @@ -21,6 +21,6 @@ return (; generated_y_t, I_t, - latent_model, + Z_t, process_aux = merge(latent_model_aux, generated_y_t_aux)) end diff --git a/EpiAware/test/test_epiawareprob.jl b/EpiAware/test/test_epiawareprob.jl new file mode 100644 index 000000000..fb447a1a3 --- /dev/null +++ b/EpiAware/test/test_epiawareprob.jl @@ -0,0 +1,23 @@ +@testitem "EpiAwareProblem Tests" begin + using Distributions + # Define test inputs + data = EpiData([0.2, 0.3, 0.5], exp) + epi_model = DirectInfections(data, Normal()) + latent_model = RandomWalk(Normal(0.0, 1.0), truncated(Normal(0.0, 0.05), 0.0, Inf)) + delay_int = [0.2, 0.3, 0.5] + time_horizon = 30 + obs_prior = default_delay_obs_priors() + + obs_model = DelayObservations(delay_int, time_horizon, + obs_prior[:neg_bin_cluster_factor_prior]) + tspan = (0, 365) + + # Create an instance of EpiAwareProblem + problem = EpiAwareProblem(epi_model, latent_model, obs_model, tspan) + + @test typeof(problem) <: EpiAwareProblem + @test typeof(problem.epi_model) <: DirectInfections + @test typeof(problem.latent_model) <: RandomWalk + @test typeof(problem.observation_model) <: DelayObservations + @test problem.tspan == (0, 365) +end diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index d86184da0..38688ac21 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -21,7 +21,8 @@ end @testset "Test case: check fail mode for bad model" begin @model function bad_model() - x ~ Normal(0, 1) + x ~ truncated(Normal(0, 1), -Inf, -1e-6) + y ~ Normal(sqrt(x), 1) #<-fails return sqrt(x) #<-fails end badmdl = bad_model()