Skip to content

Commit

Permalink
Switch to temporal storage for FB
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 11, 2023
1 parent 12fa976 commit 1b9a7f1
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 105 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[weakdeps]
Expand Down
12 changes: 6 additions & 6 deletions ext/HiddenMarkovModelsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ function ChainRulesCore.rrule(
(p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq)
fb = forward_backward(p, A, logB)
logL = HiddenMarkovModels.loglikelihood(fb)
@unpack α, β, γ, c, Bscaled, Bβscaled = fb
@unpack α, β, γ, c, Bβscaled = fb
T = length(obs_seq)

function logdensityof_hmm_pullback(ΔlogL)
Δp = ΔlogL .* Bβscaled[:, 1]
ΔA = ΔlogL .* α[:, 1] .* Bβscaled[:, 2]'
@views for t in 2:(T - 1)
ΔA .+= ΔlogL .* α[:, t] .* Bβscaled[:, t + 1]'
Δp = ΔlogL .* Bβscaled[1]
ΔA = ΔlogL .* α[1] .* Bβscaled[2]'
for t in 2:(T - 1)
ΔA .+= ΔlogL .* α[t] .* Bβscaled[t + 1]'
end
ΔlogB = ΔlogL .* γ
ΔlogB = ΔlogL .* reduce(hcat, γ)

Δlogdensityof = NoTangent()
_, Δhmm, Δobs_seq = pullback((Δp, ΔA, ΔlogB))
Expand Down
33 changes: 18 additions & 15 deletions src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ using Distributions:
UnivariateDistribution,
MultivariateDistribution,
MatrixDistribution
using LinearAlgebra: Diagonal, dot, mul!
using LinearAlgebra: Diagonal, axpy!, dot, mul!
using PrecompileTools: @compile_workload, @setup_workload
using Random: AbstractRNG, default_rng
using RequiredInterfaces: @required
using Requires: @require
using SimpleUnPack: @unpack
using SparseArrays: AbstractSparseArray, SparseMatrixCSC
using SparseArrays: nnz, nzrange, nonzeros, rowvals
using StatsAPI: StatsAPI, fit, fit!

export HMMs
Expand All @@ -51,6 +53,7 @@ include("utils/probvec.jl")
include("utils/transmat.jl")
include("utils/fit.jl")
include("utils/lightdiagnormal.jl")
include("utils/mul.jl")

include("inference/loglikelihoods.jl")
include("inference/forward.jl")
Expand All @@ -70,20 +73,20 @@ if !isdefined(Base, :get_extension)
end
end

@compile_workload begin
N, D, T = 5, 3, 100
p = rand_prob_vec(N)
A = rand_trans_mat(N)
dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
hmm = HMM(p, A, dists)
# @compile_workload begin
# N, D, T = 5, 3, 100
# p = rand_prob_vec(N)
# A = rand_trans_mat(N)
# dists = [LightDiagNormal(randn(D), ones(D)) for i in 1:N]
# hmm = HMM(p, A, dists)

obs_seqs = [last(rand(hmm, T)) for _ in 1:3]
nb_seqs = 3
logdensityof(hmm, obs_seqs, nb_seqs)
forward(hmm, obs_seqs, nb_seqs)
viterbi(hmm, obs_seqs, nb_seqs)
forward_backward(hmm, obs_seqs, nb_seqs)
baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf)
end
# obs_seqs = [last(rand(hmm, T)) for _ in 1:3]
# nb_seqs = 3
# logdensityof(hmm, obs_seqs, nb_seqs)
# forward(hmm, obs_seqs, nb_seqs)
# viterbi(hmm, obs_seqs, nb_seqs)
# forward_backward(hmm, obs_seqs, nb_seqs)
# baum_welch(hmm, obs_seqs, nb_seqs; max_iterations=2, atol=-Inf)
# end

end
9 changes: 5 additions & 4 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,25 @@ function baum_welch!(
# Pre-allocate nearly all necessary memory
logB = loglikelihoods(hmm, obs_seqs[1])
fb = initialize_forward_backward(hmm, logB)
R = eltype(fb)

logBs = Vector{typeof(logB)}(undef, length(obs_seqs))
fbs = Vector{typeof(fb)}(undef, length(obs_seqs))
@threads for k in eachindex(obs_seqs)
for k in eachindex(obs_seqs)
logBs[k] = loglikelihoods(hmm, obs_seqs[k])
fbs[k] = forward_backward(hmm, logBs[k])
end

init_count, trans_count = initialize_states_stats(fbs)
state_marginals_concat = initialize_observations_stats(fbs)
init_count, trans_count = initialize_states_stats(R, hmm)
state_marginals_concat = initialize_observations_stats(R, hmm, obs_seqs)
obs_seqs_concat = reduce(vcat, obs_seqs)
logL = loglikelihood(fbs)
logL_evolution = [logL]

for iteration in 1:max_iterations
# E step
if iteration > 1
@threads for k in eachindex(obs_seqs, logBs, fbs)
for k in eachindex(obs_seqs, logBs, fbs)
loglikelihoods!(logBs[k], hmm, obs_seqs[k])
forward_backward!(fbs[k], hmm, logBs[k])
end
Expand Down
106 changes: 54 additions & 52 deletions src/inference/forward_backward.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,45 @@
"""
ForwardBackwardStorage{R}
ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
Store forward-backward quantities with element type `R`.
# Fields
Let `X` denote the vector of hidden states and `Y` denote the vector of observations. The following fields are part of the API:
- `γ::Matrix{R}`: posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`
- `ξ::Array{R,3}`: posterior two-state marginals `ξ[i,j,t] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`
- `γ::Vector{Vector{R}}`: posterior one-state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`
- `ξ::Vector{M}`: posterior two-state marginals `ξ[t][i,j] = ℙ(X[t:t+1]=(i,j) | Y[1:T])`
The following fields are internals and subject to change:
- `α::Matrix{R}`: scaled forward variables `α[i,t]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)
- `β::Matrix{R}`: scaled backward variables `β[i,t]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)
- `c::Vector{R}`: forward variable inverse normalizations `c[t] = 1 / sum(α[:, t])`
- `logm::Vector{R}`: maximum of the observation loglikelihoods `logB`
- `Bscaled::Matrix{R}`: numerically stabilized observation likelihoods `B`
- `Bβscaled::Matrix{R}`: numerically stabilized product `Bβ`
- `α`: scaled forward variables `α[t][i]` proportional to `ℙ(Y[1:t], X[t]=i)` (up to a function of `t`)
- `β`: scaled backward variables `β[t][i]` proportional to `ℙ(Y[t+1:T] | X[t]=i)` (up to a function of `t`)
- `c`: forward variable inverse normalizations `c[t] = 1 / sum(α[:, t])`
- `logm`: maximum of the observation loglikelihoods `logB`
- `Bscaled`: numerically stabilized observation likelihoods `B`
- `Bβscaled`: numerically stabilized product `Bβ`
"""
struct ForwardBackwardStorage{R}
α::Matrix{R}
β::Matrix{R}
γ::Matrix{R}
ξ::Array{R,3}
struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}}
α::Vector{Vector{R}}
β::Vector{Vector{R}}
γ::Vector{Vector{R}}
ξ::Vector{M}
c::Vector{R}
logm::Vector{R}
Bscaled::Matrix{R}
Bβscaled::Matrix{R}
Bscaled::Vector{Vector{R}}
Bβscaled::Vector{Vector{R}}
end

Base.length(fb::ForwardBackwardStorage) = size(fb.α, 1)
duration(fb::ForwardBackwardStorage) = size(fb.α, 2)
Base.eltype(fb::ForwardBackwardStorage{R}) where {R} = R
Base.length(fb::ForwardBackwardStorage) = length(first(fb.α))

Check warning on line 34 in src/inference/forward_backward.jl

View check run for this annotation

Codecov / codecov/patch

src/inference/forward_backward.jl#L34

Added line #L34 was not covered by tests
duration(fb::ForwardBackwardStorage) = length(fb.α)

function loglikelihood(fb::ForwardBackwardStorage{R}) where {R}
logL = -sum(log, fb.c) + sum(fb.logm)
return logL
end

function loglikelihood(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
function loglikelihood(fbs::Vector{ForwardBackwardStorage{R,M}}) where {R,M}
logL = zero(R)
for fb in fbs
logL += loglikelihood(fb)
Expand All @@ -49,14 +50,16 @@ end
function initialize_forward_backward(p, A, logB)
N, T = size(logB)
R = promote_type(eltype(p), eltype(A), eltype(logB))
α = Matrix{R}(undef, N, T)
β = Matrix{R}(undef, N, T)
γ = Matrix{R}(undef, N, T)
ξ = Array{R,3}(undef, N, N, T - 1)
V = Vector{R}
M = typeof(similar(A, R))
α = V[Vector{R}(undef, N) for t in 1:T]
β = V[Vector{R}(undef, N) for t in 1:T]
γ = V[Vector{R}(undef, N) for t in 1:T]
ξ = M[similar(A, R) for t in 1:(T - 1)]
c = Vector{R}(undef, T)
logm = Vector{R}(undef, T)
Bscaled = Matrix{R}(undef, N, T)
Bβscaled = Matrix{R}(undef, N, T)
Bscaled = V[Vector{R}(undef, N) for t in 1:T]
Bβscaled = V[Vector{R}(undef, N) for t in 1:T]
return ForwardBackwardStorage(α, β, γ, ξ, c, logm, Bscaled, Bβscaled)
end

Expand All @@ -68,47 +71,46 @@ end

function forward!(fb::ForwardBackwardStorage, p, A, logB)
@unpack α, c, logm, Bscaled = fb
T = size(α, 2)
T = length)
maximum!(logm', logB)
Bscaled .= exp.(logB .- logm')
@views begin
α[:, 1] .= p .* Bscaled[:, 1]
c[1] = inv(sum(α[:, 1]))
α[:, 1] .*= c[1]
Bscaled[1] .= exp.(view(logB, :, 1) .- logm[1])
α[1] .= p .* Bscaled[1]
c[1] = inv(sum(α[1]))
α[1] .*= c[1]
for t in 1:(T - 1)
Bscaled[t + 1] .= exp.(view(logB, :, t + 1) .- logm[t + 1])
mul!(α[t + 1], A', α[t])
α[t + 1] .*= Bscaled[t + 1]
c[t + 1] = inv(sum(α[t + 1]))
α[t + 1] .*= c[t + 1]
end
@views for t in 1:(T - 1)
mul!(α[:, t + 1], A', α[:, t])
α[:, t + 1] .*= Bscaled[:, t + 1]
c[t + 1] = inv(sum(α[:, t + 1]))
α[:, t + 1] .*= c[t + 1]
end
check_no_nan(α)
return nothing
end

function backward!(fb::ForwardBackwardStorage{R}, A, logB) where {R}
@unpack β, c, Bscaled, Bβscaled = fb
T = size(β, 2)
β[:, T] .= c[T]
@views for t in (T - 1):-1:1
Bβscaled[:, t + 1] .= Bscaled[:, t + 1] .* β[:, t + 1]
mul!(β[:, t], A, Bβscaled[:, t + 1])
β[:, t] .*= c[t]
T = length)
β[T] .= c[T]
for t in (T - 1):-1:1
Bβscaled[t + 1] .= Bscaled[t + 1] .* β[t + 1]
mul!(β[t], A, Bβscaled[t + 1])
β[t] .*= c[t]
end
@views Bβscaled[:, 1] .= Bscaled[:, 1] .* β[:, 1]
check_no_nan(β)
Bβscaled[1] .= Bscaled[1] .* β[1]
return nothing
end

function marginals!(fb::ForwardBackwardStorage, A)
@unpack α, β, c, Bβscaled, γ, ξ = fb
N, T = size(γ)
γ .= α .* β ./ c'
check_no_nan(γ)
@views for t in 1:(T - 1)
ξ[:, :, t] .= α[:, t] .* A .* Bβscaled[:, t + 1]'
T = length(γ)
for t in 1:T
γ[t] .= α[t] .* β[t] ./ c[t]
end
for t in 1:(T - 1)
ξ[t] .= A
mul_rows!(ξ[t], α[t])
mul_cols!(ξ[t], Bβscaled[t + 1])
end
check_no_nan(ξ)
return nothing
end

Expand Down
33 changes: 18 additions & 15 deletions src/inference/sufficient_stats.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,42 @@

function initialize_states_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
N = length(first(fbs))
init_count = Vector{R}(undef, N)
trans_count = Matrix{R}(undef, N, N)
function initialize_states_stats(::Type{R}, hmm::AbstractHMM) where {R}
init_count = similar(initial_distribution(hmm), R)
trans_count = similar(transition_matrix(hmm), R)
return init_count, trans_count
end

function initialize_observations_stats(fbs::Vector{ForwardBackwardStorage{R}}) where {R}
N = length(first(fbs))
T_total = sum(duration, fbs)
function initialize_observations_stats(::Type{R}, hmm::AbstractHMM, obs_seqs) where {R}
N = length(hmm)
T_total = sum(length, obs_seqs)
state_marginals_concat = Matrix{R}(undef, N, T_total)
return state_marginals_concat
end

function update_states_stats!(
init_count, trans_count, fbs::Vector{ForwardBackwardStorage{R}}
) where {R}
init_count, trans_count, fbs::Vector{ForwardBackwardStorage{R,M}}
) where {R,M}
init_count .= zero(R)
for k in eachindex(fbs)
@views init_count .+= fbs[k].γ[:, 1]
init_count .+= fbs[k].γ[1]
end
trans_count .= zero(R)
for k in eachindex(fbs)
sum!(trans_count, fbs[k].ξ; init=false)
for t in eachindex(fbs[k].ξ)
mynonzeros(trans_count) .+= mynonzeros(fbs[k].ξ[t])
end
end
return nothing
end

function update_observations_stats!(
state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R}}
) where {R}
T = 1
state_marginals_concat, fbs::Vector{ForwardBackwardStorage{R,M}}
) where {R,M}
T = 0
for k in eachindex(fbs)
Tk = duration(fbs[k])
@views state_marginals_concat[:, T:(T + Tk - 1)] .= fbs[k].γ
for t in 1:Tk
@views state_marginals_concat[:, T + t] .= fbs[k].γ[t]
end
T += Tk
end
return nothing
Expand Down
7 changes: 5 additions & 2 deletions src/utils/check.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
function check_no_nan(a)
if any(isnan, a)
function check_no_nan(a::Number)
if isnan(a)
throw(OverflowError("Some values are NaN"))
end
return true
end

check_no_nan(a::AbstractArray) = all(check_no_nan, a)

function check_positive(a)
if any(!>(zero(eltype(a))), a)
throw(OverflowError("Some values are not positive"))
Expand Down
24 changes: 24 additions & 0 deletions src/utils/mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
function mul_rows!(A::AbstractMatrix, v::AbstractVector)
return A .*= v
end

function mul_cols!(A::AbstractMatrix, v::AbstractVector)
return A .*= v'
end

function mul_rows!(A::SparseMatrixCSC, v::AbstractVector)
for (k, i) in enumerate(rowvals(A))
A.nzval[k] *= v[i]
end
end

function mul_cols!(A::SparseMatrixCSC, v::AbstractVector)
for j in eachindex(v)
for k in nzrange(A, j)
A.nzval[k] *= v[j]
end
end
end

mynonzeros(x::AbstractArray) = x
mynonzeros(x::AbstractSparseArray) = nonzeros(x)
4 changes: 3 additions & 1 deletion test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ function test_correctness(hmm, hmm_init; T)
@testset "Forward-backward" begin
γ_base = HMMBase.posteriors(hmm_base, obs_mat)
fb = @inferred forward_backward(hmm, obs_seq)
@test isapprox(fb.γ, γ_base')
@test isapprox(fb.γ[1], γ_base[1, :])
@test isapprox(fb.γ[T ÷ 2], γ_base[T ÷ 2, :])
@test isapprox(fb.γ[T], γ_base[T, :])
end

@testset "Baum-Welch" begin
Expand Down
Loading

0 comments on commit 1b9a7f1

Please sign in to comment.