Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 251: Add aggregation functionality #397

Merged
merged 11 commits into from
Jul 25, 2024
2 changes: 1 addition & 1 deletion EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel
export RepeatEach, RepeatBlock

# Export helper functions
export broadcast_dayofweek, broadcast_weekly, equal_dimensions
export broadcast_rule, broadcast_dayofweek, broadcast_weekly, equal_dimensions

# Export tools for modifying latent models
export DiffLatentModel, TransformLatentModel, PrefixLatentModel, RecordExpectedLatent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ Generates latent periods using the specified `model` and `n` number of samples.

## Returns
- `broadcasted_latent`: The generated broadcasted latent periods.
- `latent_period_aux...`: Additional auxiliary information about the latent periods.

"
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ It repeats the latent process at each period. An example of this rule is to repe
```julia
using EpiAware
rule = RepeatEach()
latent = [1, 2, 3]
latent = [1, 2]
n = 10
period = 2
broadcast_rule(rule, latent, n, period)
Expand Down
5 changes: 4 additions & 1 deletion EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using ..EpiAwareBase

using ..EpiAwareUtils

using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek, PrefixLatentModel
using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
using ..EpiLatentModels: broadcast_rule, PrefixLatentModel, RepeatEach

using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra

Expand All @@ -19,6 +20,7 @@ export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, PrefixObservationModel, RecordExpectedObs
export Aggregate

# Observation model manipulators
export StackObservationModels
Expand All @@ -30,6 +32,7 @@ include("docstrings.jl")
include("modifiers/LatentDelay.jl")
include("modifiers/ascertainment/Ascertainment.jl")
include("modifiers/ascertainment/helpers.jl")
include("modifiers/Aggregate.jl")
include("modifiers/PrefixObservationModel.jl")
include("modifiers/RecordExpectedObs.jl")
include("StackObservationModels.jl")
Expand Down
67 changes: 67 additions & 0 deletions EpiAware/src/EpiObsModels/modifiers/Aggregate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
@doc raw"
Aggregates observations over a specified time period. For efficiency it also only passes the aggregated observations to the submodel. The aggregation vector
is internally broadcasted to the length of the observations and the present vector is broadcasted to the length of the aggregation vector using `broadcast_n`.

# Fields

- `model::AbstractTuringObservationModel`: The submodel to use for the aggregated observations.
- `aggregation::AbstractVector{<: Int}`: The number of time periods to aggregate over.
- `present::AbstractVector{<: Bool}`: A vector of booleans indicating whether the observation is present or not.

# Constructors

- `Aggregate(model, aggregation)`: Constructs an `Aggregate` object and automatically sets the `present` field.
- `Aggregate(; model, aggregation)`: Constructs an `Aggregate` object and automatically sets the `present` field using named keyword arguments

# Examples

```julia
using EpiAware
weekly_agg = Aggregate(PoissonError(), [0, 0, 0, 0, 7, 0, 0])
gen_obs = generate_observations(weekly_agg, missing, fill(1, 28))
gen_obs()
```
"
struct Aggregate{M <: AbstractTuringObservationModel,
I <: AbstractVector{<:Int}, J <: AbstractVector{<:Bool}} <:
AbstractTuringObservationModel
model::M
aggregation::I
present::J

function Aggregate(model, aggregation)
present = aggregation .!= 0
new{typeof(model), typeof(aggregation), typeof(present)}(
model, aggregation, present)
end

function Aggregate(; model, aggregation)
return Aggregate(model, aggregation)
end
end

@model function EpiAwareBase.generate_observations(ag::Aggregate, y_t, Y_t)
if ismissing(y_t)
y_t = Vector{Missing}(missing, length(Y_t))
end

n = length(y_t)
m = length(ag.aggregation)

aggregation = broadcast_rule(RepeatEach(), ag.aggregation, n, m)
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved

present = broadcast_rule(RepeatEach(), ag.present, n, m)

agg_Y_t = map(findall(present)) do i
sum(Y_t[max(1, i - aggregation[i] + 1):i])
end

@submodel pred_obs = generate_observations(ag.model, y_t[present], agg_Y_t)
return _return_aggregate(pred_obs, present, n)
end

function _return_aggregate(pred_obs, present, n)
y_t = zeros(eltype(pred_obs), n)
y_t[present] = pred_obs
return y_t
end
30 changes: 30 additions & 0 deletions EpiAware/test/EpiObsModels/modifiers/Aggregate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@testitem "Aggregate constructor works as expected" begin
weekly_agg = Aggregate(PoissonError(), [0, 0, 0, 0, 7, 0, 0])
@test weekly_agg.model == PoissonError()
@test weekly_agg.aggregation == [0, 0, 0, 0, 7, 0, 0]
@test weekly_agg.present == [false, false, false, false, true, false, false]

weekly_agg = Aggregate(model = PoissonError(), aggregation = [0, 0, 0, 0, 7, 0, 0])
@test weekly_agg.model == PoissonError()
@test weekly_agg.aggregation == [0, 0, 0, 0, 7, 0, 0]
end

@testitem "Aggregate generate_observations works as expected" begin
using Turing
struct TestObs <: AbstractTuringObservationModel end

@model function EpiAwareBase.generate_observations(::TestObs, y_t, Y_t)
return Y_t
end
weekly_agg = Aggregate(TestObs(), [0, 0, 0, 0, 7, 0, 0])
gen_obs = generate_observations(weekly_agg, missing, fill(1, 28))
draws = gen_obs()
@test draws isa Vector{Int64}
@test length(draws) == 28
exp_draws = fill(0.0, 28)
exp_draws[5] = 5.0
exp_draws[12] = 7.0
exp_draws[19] = 7.0
exp_draws[26] = 7.0
@test draws == exp_draws
end
6 changes: 6 additions & 0 deletions benchmark/bench/EpiObsModels/modifiers/Aggregate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
let
seabbs marked this conversation as resolved.
Show resolved Hide resolved
I_t = fill(10, 100)
weekly_agg = Aggregate(PoissonError(), [0, 0, 0, 0, 7, 0, 0])
mdl = generate_observations(weekly_agg, I_t, I_t)
suite["Aggregate"] = make_epiaware_suite(mdl)
end
Loading