From 64262cdc9daac0ff2828fbcbfa9f0678e342477e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 19:20:40 +0100 Subject: [PATCH 1/8] Speed up viterbi --- Project.toml | 2 +- docs/src/api.md | 15 +++++++ examples/basics.jl | 8 ++-- examples/types.jl | 3 ++ libs/HMMTest/src/coherence.jl | 1 + src/HiddenMarkovModels.jl | 2 +- src/inference/forward.jl | 4 +- src/inference/viterbi.jl | 41 ++++++++---------- src/precompile.jl | 2 +- src/types/abstract_hmm.jl | 52 +++++++++++++++++----- src/types/hmm.jl | 15 ++++++- src/utils/linalg.jl | 82 ++++++++++++++++++++++++++++++++--- src/utils/valid.jl | 21 ++++++--- test/correctness.jl | 13 ++++++ 14 files changed, 202 insertions(+), 59 deletions(-) diff --git a/Project.toml b/Project.toml index 82ef37cd..1d2abae5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.4.1" +version = "0.4.2" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/docs/src/api.md b/docs/src/api.md index 05714c63..586e0502 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,13 @@ transition_matrix obs_distributions ``` +Optional log versions: + +```@docs +log_initialization +log_transition_matrix +``` + ## Utils ```@docs @@ -97,6 +104,7 @@ HiddenMarkovModels.baum_welch! ## Misc ```@docs +HiddenMarkovModels.valid_hmm HiddenMarkovModels.rand_prob_vec HiddenMarkovModels.rand_trans_mat HiddenMarkovModels.LightDiagNormal @@ -104,6 +112,13 @@ HiddenMarkovModels.LightCategorical HiddenMarkovModels.fit_in_sequence! ``` +## Internals + +```@docs +HiddenMarkovModels.mul_rows_cols! +HiddenMarkovModels.argmaxplus_mul! +``` + ## Index ```@index diff --git a/examples/basics.jl b/examples/basics.jl index 034583c7..7f14d863 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -19,10 +19,10 @@ rng = StableRNG(63); # ## Model #= -The package provides a versatile [`HMM`](@ref) type with three attributes: -- a vector of state initialization probabilities -- a matrix of state transition probabilities -- a vector of observation distributions, one for each state +The package provides a versatile [`HMM`](@ref) type with three main attributes: +- a vector `init` of state initialization probabilities +- a matrix `trans` of state transition probabilities +- a vector `dists` of observation distributions, one for each state Any scalar- or vector-valued distribution from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) can be used for the last part, as well as [Custom distributions](@ref). =# diff --git a/examples/types.jl b/examples/types.jl index 7160dd27..03f1da7f 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -6,6 +6,7 @@ Here we explain why playing with different number and array types can be useful using Distributions using HiddenMarkovModels +using HiddenMarkovModels: log_transition_matrix #src using HMMTest #src using LinearAlgebra using LogarithmicNumbers @@ -149,6 +150,8 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St # ## Tests #src +@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src + seq_ends = cumsum(rand(rng, 100:200, 100)); #src control_seqs = fill(nothing, length(seq_ends)); #src test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 2b4ce8d6..da111f3f 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -21,6 +21,7 @@ function test_equal_hmms( for control in control_seq trans1 = transition_matrix(hmm1, control) trans2 = transition_matrix(hmm2, control) + @test HMMs.mynnz(trans1) == HMMs.mynnz(trans2) if flip @test !isapprox(trans1, trans2; atol, norm=infnorm) else diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 8c4fedce..dbb66707 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -19,7 +19,7 @@ using FillArrays: Fill using LinearAlgebra: dot, ldiv!, lmul!, mul! using PrecompileTools: @compile_workload using Random: Random, AbstractRNG, default_rng -using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange +using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals using StatsAPI: StatsAPI, fit, fit! using StatsFuns: log2π diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 2849f795..c7d4883c 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -71,8 +71,8 @@ function forward!( Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) trans = transition_matrix(hmm, control_seq[t]) - αₜ₊₁ = view(α, :, t + 1) - mul!(αₜ₊₁, trans', view(α, :, t)) + αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1) + mul!(αₜ₊₁, transpose(trans), αₜ) αₜ₊₁ .*= Bₜ₊₁ c[t + 1] = inv(sum(αₜ₊₁)) lmul!(c[t + 1], αₜ₊₁) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 1795d772..32e8ab34 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -13,7 +13,7 @@ struct ViterbiStorage{R} "one joint loglikelihood per pair of observation sequence and most likely state sequence" logL::Vector{R} logB::Matrix{R} - logϕ::Matrix{R} + ϕ::Matrix{R} ψ::Matrix{Int} end @@ -33,9 +33,9 @@ function initialize_viterbi( q = Vector{Int}(undef, T) logL = Vector{R}(undef, K) logB = Matrix{R}(undef, N, T) - logϕ = Matrix{R}(undef, N, T) + ϕ = Matrix{R}(undef, N, T) ψ = Matrix{Int}(undef, N, T) - return ViterbiStorage(q, logL, logB, logϕ, ψ) + return ViterbiStorage(q, logL, logB, ϕ, ψ) end """ @@ -49,31 +49,26 @@ function viterbi!( t1::Integer, t2::Integer; ) where {R} - (; q, logB, logϕ, ψ) = storage + (; q, logB, ϕ, ψ) = storage - obs_logdensities!(view(logB, :, t1), hmm, obs_seq[t1], control_seq[t1]) - init = initialization(hmm) - logϕ[:, t1] .= log.(init) .+ view(logB, :, t1) + logBₜ₁ = view(logB, :, t1) + obs_logdensities!(logBₜ₁, hmm, obs_seq[t1], control_seq[t1]) + loginit = log_initialization(hmm) + ϕ[:, t1] .= loginit .+ logBₜ₁ for t in (t1 + 1):t2 - obs_logdensities!(view(logB, :, t), hmm, obs_seq[t], control_seq[t]) - trans = transition_matrix(hmm, control_seq[t - 1]) - for j in 1:length(hmm) - i_max = 1 - score_max = logϕ[i_max, t - 1] + log(trans[i_max, j]) - for i in 2:length(hmm) - score = logϕ[i, t - 1] + log(trans[i, j]) - if score > score_max - score_max, i_max = score, i - end - end - ψ[j, t] = i_max - logϕ[j, t] = score_max + logB[j, t] - end + logBₜ = view(logB, :, t) + obs_logdensities!(logBₜ, hmm, obs_seq[t], control_seq[t]) + logtrans = log_transition_matrix(hmm, control_seq[t - 1]) + ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) + ψₜ = view(ψ, :, t) + argmaxplus_mul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) + ϕₜ .+= logBₜ end - q[t2] = argmax(view(logϕ, :, t2)) - logL = logϕ[q[t2], t2] + ϕₜ₂ = view(ϕ, :, t2) + q[t2] = argmax(ϕₜ₂) + logL = ϕ[q[t2], t2] for t in (t2 - 1):-1:t1 q[t] = ψ[q[t + 1], t + 1] end diff --git a/src/precompile.jl b/src/precompile.jl index 535deb05..9837c3c5 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -6,8 +6,8 @@ hmm = HMM(init, trans, dists) state_seq, obs_seq = rand(hmm, T) - logdensityof(hmm, obs_seq, state_seq) logdensityof(hmm, obs_seq) + joint_logdensityof(hmm, obs_seq, state_seq) forward(hmm, obs_seq) viterbi(hmm, obs_seq) forward_backward(hmm, obs_seq) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index ad3fe59a..afdacf1e 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -58,13 +58,34 @@ Return the vector of initial state probabilities for `hmm`. """ function initialization end +""" + log_initialization(hmm) + +Return the vector of initial state log-probabilities for `hmm`. + +Falls back on `initialization`. +""" +log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) + """ transition_matrix(hmm) transition_matrix(hmm, control) Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied). """ -transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) +function transition_matrix end + +""" + log_transition_matrix(hmm) + log_transition_matrix(hmm, control) + +Return the matrix of state transition log-probabilities for `hmm` (possibly when `control` is applied). + +Falls back on `transition_matrix`. +""" +function log_transition_matrix(hmm::AbstractHMM, control) + return elementwise_log(transition_matrix(hmm, control)) +end """ obs_distributions(hmm) @@ -78,18 +99,12 @@ These distribution objects should implement - `DensityInterface.logdensityof(dist, obs)` for inference - `StatsAPI.fit!(dist, obs_seq, weight_seq)` for learning """ -obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) +function obs_distributions end -function obs_logdensities!( - logb::AbstractVector{T}, hmm::AbstractHMM, obs, control -) where {T} - dists = obs_distributions(hmm, control) - @inbounds @simd for i in eachindex(logb, dists) - logb[i] = logdensityof(dists[i], obs) - end - @argcheck maximum(logb) < typemax(T) - return nothing -end +## Fallbacks for no control + +transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) +obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) """ StatsAPI.fit!( @@ -103,6 +118,19 @@ This function is allowed to reuse `fb_storage` as a scratch space, so its conten """ StatsAPI.fit! +## Fill logdensities + +function obs_logdensities!( + logb::AbstractVector{T}, hmm::AbstractHMM, obs, control +) where {T} + dists = obs_distributions(hmm, control) + @inbounds @simd for i in eachindex(logb, dists) + logb[i] = logdensityof(dists[i], obs) + end + @argcheck maximum(logb) < typemax(T) + return nothing +end + ## Sampling """ diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 955c8ace..a9310798 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -10,13 +10,19 @@ $(TYPEDFIELDS) struct HMM{V<:AbstractVector,M<:AbstractMatrix,VD<:AbstractVector} <: AbstractHMM "initial state probabilities" init::V - "state transition matrix" + "state transition probabilities" trans::M "observation distributions" dists::VD + "logarithms of initial state probabilities" + loginit::V + "logarithms of state transition probabilities" + logtrans::M function HMM(init::AbstractVector, trans::AbstractMatrix, dists::AbstractVector) - hmm = new{typeof(init),typeof(trans),typeof(dists)}(init, trans, dists) + hmm = new{typeof(init),typeof(trans),typeof(dists)}( + init, trans, dists, elementwise_log(init), elementwise_log(trans) + ) @argcheck valid_hmm(hmm) return hmm end @@ -34,7 +40,9 @@ function Base.show(io::IO, hmm::HMM) end initialization(hmm::HMM) = hmm.init +log_initialization(hmm::HMM) = hmm.loginit transition_matrix(hmm::HMM) = hmm.trans +log_transition_matrix(hmm::HMM) = hmm.logtrans obs_distributions(hmm::HMM) = hmm.dists ## Fitting @@ -69,6 +77,9 @@ function StatsAPI.fit!( for i in 1:length(hmm) fit_in_sequence!(hmm.dists, i, obs_seq, view(γ, i, :)) end + # Update logs + hmm.loginit .= log.(hmm.init) + mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans)) # Safety check @argcheck valid_hmm(hmm) return nothing diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index e3fe2b6f..031887dc 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -3,26 +3,94 @@ sum_to_one!(x) = ldiv!(sum(x), x) mynonzeros(x::AbstractArray) = x mynonzeros(x::AbstractSparseArray) = nonzeros(x) -mynnz(x) = length(mynonzeros(x)) +mynnz(x::AbstractArray) = length(mynonzeros(x)) +elementwise_log(x::AbstractArray) = log.(x) + +function elementwise_log(A::SparseMatrixCSC) + return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval)) +end + +""" + mul_rows_cols!(B, l, A, r) + +Perform the in-place operation `B .= l .* A .* transpose(r)`. +""" function mul_rows_cols!( B::AbstractMatrix, l::AbstractVector, A::AbstractMatrix, r::AbstractVector ) - B .= l .* A .* r' - return nothing + B .= l .* A .* transpose(r) + return B end function mul_rows_cols!( B::SparseMatrixCSC, l::AbstractVector, A::SparseMatrixCSC, r::AbstractVector ) - @argcheck size(B) == size(A) == (length(l), length(r)) + @argcheck axes(A, 1) == eachindex(r) + @argcheck axes(A, 2) == eachindex(l) + @argcheck size(A) == size(B) @argcheck nnz(B) == nnz(A) + Brv = rowvals(B) + Bnz = nonzeros(B) + Anz = nonzeros(A) for j in axes(B, 2) @argcheck nzrange(B, j) == nzrange(A, j) for k in nzrange(B, j) - i = B.rowval[k] - B.nzval[k] = l[i] * A.nzval[k] * r[j] + i = Brv[k] + Bnz[k] = l[i] * Anz[k] * r[j] + end + end + return B +end + +""" + argmaxplus_mul!(y, ind, A, x) + +Perform the in-place multiplication `A * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each row in `ind`. +""" +function argmaxplus_mul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::AbstractMatrix, + x::AbstractVector, +) where {R} + @argcheck axes(A, 1) == eachindex(y) + @argcheck axes(A, 2) == eachindex(x) + y .= typemin(R) + ind .= 0 + for j in axes(A, 2) + for i in axes(A, 1) + z = A[i, j] + x[j] + if z > y[i] + y[i] = z + ind[i] = j + end + end + end + return y +end + +function argmaxplus_mul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::SparseMatrixCSC, + x::AbstractVector, +) where {R} + @argcheck axes(A, 1) == eachindex(y) + @argcheck axes(A, 2) == eachindex(x) + y .= typemin(R) + ind .= 0 + Anz = nonzeros(A) + Arv = rowvals(A) + for j in axes(A, 2) + for k in nzrange(A, j) + i, a = Arv[k], Anz[k] + z = a + x[j] + if z > y[i] + y[i] = z + ind[i] = j + end end end - return nothing + return y end diff --git a/src/utils/valid.jl b/src/utils/valid.jl index ed147691..2ee9830f 100644 --- a/src/utils/valid.jl +++ b/src/utils/valid.jl @@ -15,14 +15,23 @@ function valid_dists(d::AbstractVector) return true end +""" + valid_hmm(hmm) + +Perform some checks to rule out obvious inconsistencies with an `AbstractHMM` object. +""" function valid_hmm(hmm::AbstractHMM, control=nothing) init = initialization(hmm) trans = transition_matrix(hmm, control) dists = obs_distributions(hmm, control) - return ( - length(init) == length(dists) == size(trans, 1) == size(trans, 2) && - valid_prob_vec(init) && - valid_trans_mat(trans) && - valid_dists(dists) - ) + if !(length(init) == length(dists) == size(trans, 1) == size(trans, 2)) + return false + elseif !valid_prob_vec(init) + return false + elseif !valid_trans_mat(trans) + return false + elseif !valid_dists(dists) + return false + end + return true end diff --git a/test/correctness.jl b/test/correctness.jl index 291f6e23..3ebe2b5d 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -85,3 +85,16 @@ end test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) end + +@testset "Normal (sparse)" begin + dists = [Normal(μ[1][1]), Normal(μ[2][1])] + dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] + + hmm = HMM(init, sparse(trans), dists) + hmm_guess = HMM(init_guess, trans_guess, dists_guess) + + test_identical_hmmbase(rng, hmm, T; hmm_guess) + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) +end From 58da66e084eb846011eb2c639d8b1f5ba711b835 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 19:26:31 +0100 Subject: [PATCH 2/8] Force controls everywhere --- Project.toml | 2 +- examples/controlled.jl | 5 +++-- src/types/abstract_hmm.jl | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 1d2abae5..17fd0925 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.4.2" +version = "0.5.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/examples/controlled.jl b/examples/controlled.jl index a3c9ca72..c7b70857 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -32,14 +32,15 @@ struct ControlledGaussianHMM{T} <: AbstractHMM end #= -In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$. +In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$. +Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one. =# function HMMs.initialization(hmm::ControlledGaussianHMM) return hmm.init end -function HMMs.transition_matrix(hmm::ControlledGaussianHMM) +function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector) return hmm.trans end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index afdacf1e..6f3f8e53 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -103,8 +103,9 @@ function obs_distributions end ## Fallbacks for no control -transition_matrix(hmm::AbstractHMM, control) = transition_matrix(hmm) -obs_distributions(hmm::AbstractHMM, control) = obs_distributions(hmm) +transition_matrix(hmm::AbstractHMM, ::Nothing) = transition_matrix(hmm) +log_transition_matrix(hmm::AbstractHMM, ::Nothing) = log_transition_matrix(hmm) +obs_distributions(hmm::AbstractHMM, ::Nothing) = obs_distributions(hmm) """ StatsAPI.fit!( From 5c0367c0cae503d3734e30dce38c133e5fb27a9e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:21:30 +0100 Subject: [PATCH 3/8] Remove sparse special case --- README.md | 4 ++-- benchmark/Manifest.toml | 2 +- docs/src/api.md | 15 +++++---------- examples/basics.jl | 4 ++-- examples/controlled.jl | 12 ++++++------ examples/interfaces.jl | 6 +++--- examples/temporal.jl | 6 +++--- examples/types.jl | 14 +++++++++----- libs/HMMTest/src/coherence.jl | 4 +++- libs/HMMTest/src/hmmbase.jl | 3 +-- src/HiddenMarkovModels.jl | 2 +- src/inference/logdensity.jl | 9 +++++---- src/inference/viterbi.jl | 2 +- src/types/hmm.jl | 2 +- src/utils/linalg.jl | 29 ----------------------------- 15 files changed, 43 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index 9f78227e..9928a9c3 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,8 @@ Then, you can create your first model as follows: ```julia using Distributions, HiddenMarkovModels -init = [0.4, 0.6] -trans = [0.9 0.1; 0.2 0.8] +init = [0.6, 0.4] +trans = [0.7 0.3; 0.2 0.8] dists = [Normal(-1.0), Normal(1.0)] hmm = HMM(init, trans, dists) ``` diff --git a/benchmark/Manifest.toml b/benchmark/Manifest.toml index 8e19bf18..4d1de94f 100644 --- a/benchmark/Manifest.toml +++ b/benchmark/Manifest.toml @@ -169,7 +169,7 @@ version = "0.1.0" deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PrecompileTools", "Random", "SparseArrays", "StatsAPI", "StatsFuns"] path = ".." uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" -version = "0.4.0" +version = "0.5.0" weakdeps = ["Distributions"] [deps.HiddenMarkovModels.extensions] diff --git a/docs/src/api.md b/docs/src/api.md index 586e0502..9459960f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,13 +36,6 @@ transition_matrix obs_distributions ``` -Optional log versions: - -```@docs -log_initialization -log_transition_matrix -``` - ## Utils ```@docs @@ -101,20 +94,22 @@ HiddenMarkovModels.forward_backward! HiddenMarkovModels.baum_welch! ``` -## Misc +## Miscellaneous ```@docs HiddenMarkovModels.valid_hmm HiddenMarkovModels.rand_prob_vec HiddenMarkovModels.rand_trans_mat -HiddenMarkovModels.LightDiagNormal -HiddenMarkovModels.LightCategorical HiddenMarkovModels.fit_in_sequence! ``` ## Internals ```@docs +HiddenMarkovModels.LightDiagNormal +HiddenMarkovModels.LightCategorical +HiddenMarkovModels.log_initialization +HiddenMarkovModels.log_transition_matrix HiddenMarkovModels.mul_rows_cols! HiddenMarkovModels.argmaxplus_mul! ``` diff --git a/examples/basics.jl b/examples/basics.jl index 7f14d863..13f1b45f 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -28,7 +28,7 @@ Any scalar- or vector-valued distribution from [Distributions.jl](https://github =# init = [0.6, 0.4] -trans = [0.7 0.3; 0.3 0.7] +trans = [0.7 0.3; 0.2 0.8] dists = [MvNormal([-0.5, -0.8], I), MvNormal([0.5, 0.8], I)] hmm = HMM(init, trans, dists) @@ -142,7 +142,7 @@ Since it is a local optimization procedure, it requires a starting point that is =# init_guess = [0.5, 0.5] -trans_guess = [0.6 0.4; 0.4 0.6] +trans_guess = [0.6 0.4; 0.3 0.7] dists_guess = [MvNormal([-0.4, -0.7], I), MvNormal([0.4, 0.7], I)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); diff --git a/examples/controlled.jl b/examples/controlled.jl index c7b70857..1f2451b9 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -55,8 +55,8 @@ In this case, the transition matrix does not depend on the control. # ## Simulation d = 3 -init = [0.8, 0.2] -trans = [0.7 0.3; 0.3 0.7] +init = [0.6, 0.4] +trans = [0.7 0.3; 0.2 0.8] dist_coeffs = [-ones(d), ones(d)] hmm = ControlledGaussianHMM(init, trans, dist_coeffs); @@ -123,9 +123,9 @@ end Now we put it to the test. =# -init_guess = [0.7, 0.3] -trans_guess = [0.6 0.4; 0.4 0.6] -dist_coeffs_guess = [-0.7 * ones(d), 0.7 * ones(d)] +init_guess = [0.5, 0.5] +trans_guess = [0.6 0.4; 0.3 0.7] +dist_coeffs_guess = [-1.1 * ones(d), 1.1 * ones(d)] hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess); #- @@ -137,7 +137,7 @@ first(loglikelihood_evolution), last(loglikelihood_evolution) How did we perform? =# -cat(transition_matrix(hmm_est), transition_matrix(hmm); dims=3) +cat(hmm_est.trans, hmm.trans; dims=3) #- diff --git a/examples/interfaces.jl b/examples/interfaces.jl index 18e0d7b5..ba306051 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -82,7 +82,7 @@ Let's put it to the test. =# init = [0.6, 0.4] -trans = [0.7 0.3; 0.3 0.7] +trans = [0.7 0.3; 0.2 0.8] dists = [StuffDist(-1.0), StuffDist(+1.0)] hmm = HMM(init, trans, dists); @@ -104,8 +104,8 @@ If we implement `fit!`, Baum-Welch also works seamlessly. =# init_guess = [0.5, 0.5] -trans_guess = [0.6 0.4; 0.4 0.6] -dists_guess = [StuffDist(-0.7), StuffDist(+0.7)] +trans_guess = [0.6 0.4; 0.3 0.7] +dists_guess = [StuffDist(-1.1), StuffDist(+1.1)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #- diff --git a/examples/temporal.jl b/examples/temporal.jl index dee15c28..9c9549f4 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -55,7 +55,7 @@ end # ## Simulation init = [0.6, 0.4] -trans_per = ([0.7 0.3; 0.3 0.7], [0.3 0.7; 0.7 0.3]) +trans_per = ([0.7 0.3; 0.2 0.8], [0.3 0.7; 0.8 0.2]) dists_per = ([Normal(-1.0), Normal(-2.0)], [Normal(+1.0), Normal(+2.0)]) hmm = PeriodicHMM(init, trans_per, dists_per); @@ -152,8 +152,8 @@ Now let's test our procedure with a reasonable guess. =# init_guess = [0.7, 0.3] -trans_per_guess = ([0.6 0.4; 0.4 0.6], [0.4 0.6; 0.6 0.4]) -dists_per_guess = ([Normal(-0.7), Normal(-1.7)], [Normal(+0.7), Normal(+1.7)]) +trans_per_guess = ([0.6 0.4; 0.3 0.7], [0.4 0.6; 0.7 0.3]) +dists_per_guess = ([Normal(-1.1), Normal(-2.1)], [Normal(+1.1), Normal(+2.1)]) hmm_guess = PeriodicHMM(init_guess, trans_per_guess, dists_per_guess); #= diff --git a/examples/types.jl b/examples/types.jl index 03f1da7f..d4966a5f 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -41,7 +41,7 @@ To give an example, let us first generate some data from a vanilla HMM. =# init = [0.6, 0.4] -trans = [0.7 0.3; 0.3 0.7] +trans = [0.7 0.3; 0.2 0.8] dists = [Normal(-1.0), Normal(1.0)] hmm = HMM(init, trans, dists) state_seq, obs_seq = rand(rng, hmm, 100); @@ -58,6 +58,10 @@ hmm_uncertain = HMM(init, trans, dists_guess) Every quantity we compute with this new HMM will have propagated uncertainties around it. =# +logdensityof(hmm, obs_seq) + +#- + logdensityof(hmm_uncertain, obs_seq) #= @@ -130,7 +134,7 @@ trans_guess = sparse([ 0 0.6 0.4 0.4 0 0.6 ]) -dists_guess = [Normal(1.2), Normal(2.2), Normal(3.2)] +dists_guess = [Normal(1.1), Normal(2.1), Normal(3.1)] hmm_guess = HMM(init_guess, trans_guess, dists_guess); #- @@ -150,12 +154,12 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St # ## Tests #src -@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src +@test_broken nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src seq_ends = cumsum(rand(rng, 100:200, 100)); #src -control_seqs = fill(nothing, length(seq_ends)); #src +control_seq = fill(nothing, last(seq_ends)); #src test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index da111f3f..2a39e336 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -21,7 +21,9 @@ function test_equal_hmms( for control in control_seq trans1 = transition_matrix(hmm1, control) trans2 = transition_matrix(hmm2, control) - @test HMMs.mynnz(trans1) == HMMs.mynnz(trans2) + if typeof(trans1) == typeof(trans2) + @test HMMs.mynnz(trans1) == HMMs.mynnz(trans2) + end if flip @test !isapprox(trans1, trans2; atol, norm=infnorm) else diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/src/hmmbase.jl index fe1e48e3..808e2e0c 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/src/hmmbase.jl @@ -26,8 +26,7 @@ function test_identical_hmmbase( q_base = HMMBase.viterbi(hmm_base, obs_mat) q, logL_viterbi = viterbi(hmm, obs_seq; seq_ends) - # Viterbi decoding can vary in case of (infrequent) ties - @test mean(q[1:T] .== q_base) > 0.9 && mean(q[(T + 1):(2T)] .== q_base) > 0.9 + @test all(q[1:T] .== q_base) && all(q[(T + 1):(2T)] .== q_base) γ_base = HMMBase.posteriors(hmm_base, obs_mat) γ, logL_forward_backward = forward_backward(hmm, obs_seq; seq_ends) diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index dbb66707..db5c7329 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof using DocStringExtensions using FillArrays: Fill -using LinearAlgebra: dot, ldiv!, lmul!, mul! +using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent using PrecompileTools: @compile_workload using Random: Random, AbstractRNG, default_rng using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index f43fb25c..d731ed5b 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -30,12 +30,13 @@ function joint_logdensityof( for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) # Initialization - init = initialization(hmm) - logL += log(init[state_seq[t1]]) + loginit = log_initialization(hmm) + logL += loginit[state_seq[t1]] # Transitions for t in t1:(t2 - 1) - trans = transition_matrix(hmm, control_seq[t]) - logL += log(trans[state_seq[t], state_seq[t + 1]]) + logtrans = log.(transition_matrix(hmm, control_seq[t])) + # logtrans = log_transition_matrix(hmm, control_seq[t]) + logL += logtrans[state_seq[t], state_seq[t + 1]] end # Observations for t in t1:t2 diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 32e8ab34..abf254e8 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -62,7 +62,7 @@ function viterbi!( logtrans = log_transition_matrix(hmm, control_seq[t - 1]) ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) - argmaxplus_mul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) + argmaxplus_mul!(ϕₜ, ψₜ, transpose(logtrans), ϕₜ₋₁) ϕₜ .+= logBₜ end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index a9310798..476eaeae 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -79,7 +79,7 @@ function StatsAPI.fit!( end # Update logs hmm.loginit .= log.(hmm.init) - mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans)) + hmm.logtrans .= log.(hmm.trans) # Safety check @argcheck valid_hmm(hmm) return nothing diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 031887dc..07f8f31a 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -7,10 +7,6 @@ mynnz(x::AbstractArray) = length(mynonzeros(x)) elementwise_log(x::AbstractArray) = log.(x) -function elementwise_log(A::SparseMatrixCSC) - return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval)) -end - """ mul_rows_cols!(B, l, A, r) @@ -69,28 +65,3 @@ function argmaxplus_mul!( end return y end - -function argmaxplus_mul!( - y::AbstractVector{R}, - ind::AbstractVector{<:Integer}, - A::SparseMatrixCSC, - x::AbstractVector, -) where {R} - @argcheck axes(A, 1) == eachindex(y) - @argcheck axes(A, 2) == eachindex(x) - y .= typemin(R) - ind .= 0 - Anz = nonzeros(A) - Arv = rowvals(A) - for j in axes(A, 2) - for k in nzrange(A, j) - i, a = Arv[k], Anz[k] - z = a + x[j] - if z > y[i] - y[i] = z - ind[i] = j - end - end - end - return y -end From cfb7cf6ad256ee1be043ceca58f7bad6c22f8b0a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:24:28 +0100 Subject: [PATCH 4/8] Fix joint logdensityof --- src/inference/logdensity.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index d731ed5b..ec970cb3 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -34,8 +34,7 @@ function joint_logdensityof( logL += loginit[state_seq[t1]] # Transitions for t in t1:(t2 - 1) - logtrans = log.(transition_matrix(hmm, control_seq[t])) - # logtrans = log_transition_matrix(hmm, control_seq[t]) + logtrans = log_transition_matrix(hmm, control_seq[t]) logL += logtrans[state_seq[t], state_seq[t + 1]] end # Observations From 559b6dc13fc69eead7b9fa67f17e0f27aa270ac7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:25:13 +0100 Subject: [PATCH 5/8] Fix joint logdensityof again --- src/inference/logdensity.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index ec970cb3..f43fb25c 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -30,12 +30,12 @@ function joint_logdensityof( for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) # Initialization - loginit = log_initialization(hmm) - logL += loginit[state_seq[t1]] + init = initialization(hmm) + logL += log(init[state_seq[t1]]) # Transitions for t in t1:(t2 - 1) - logtrans = log_transition_matrix(hmm, control_seq[t]) - logL += logtrans[state_seq[t], state_seq[t + 1]] + trans = transition_matrix(hmm, control_seq[t]) + logL += log(trans[state_seq[t], state_seq[t + 1]]) end # Observations for t in t1:t2 From 2746f09e9e820f97661e65b2a02bd1a4ae1da636 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:33:52 +0100 Subject: [PATCH 6/8] Give PR write permissions to benchmark workflow --- .github/workflows/benchmark.yml | 3 +++ libs/HMMBenchmark/src/hiddenmarkovmodels.jl | 16 ++++++++-------- libs/HMMComparison/src/dynamax.jl | 8 ++++---- libs/HMMComparison/src/hmmbase.jl | 8 ++++---- libs/HMMComparison/src/hmmlearn.jl | 8 ++++---- libs/HMMComparison/src/pomegranate.jl | 6 +++--- 6 files changed, 26 insertions(+), 23 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 97dcad51..c4caf436 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -8,6 +8,9 @@ on: jobs: Benchmark: runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write if: contains(github.event.pull_request.labels.*.name, 'run benchmark') steps: - uses: actions/checkout@v2 diff --git a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl index 63ca122b..88d6dfa8 100644 --- a/libs/HMMBenchmark/src/hiddenmarkovmodels.jl +++ b/libs/HMMBenchmark/src/hiddenmarkovmodels.jl @@ -43,12 +43,12 @@ function build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward!" in algos benchs["forward!"] = @benchmarkable begin forward!(f_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( f_storage = initialize_forward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -56,12 +56,12 @@ function build_benchmarkables( if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "viterbi!" in algos benchs["viterbi!"] = @benchmarkable begin viterbi!(v_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( v_storage = initialize_viterbi($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) ) end @@ -69,12 +69,12 @@ function build_benchmarkables( if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin forward_backward($hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward_backward!" in algos benchs["forward_backward!"] = @benchmarkable begin forward_backward!(fb_storage, $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( fb_storage = initialize_forward_backward( $hmm, $obs_seq, $control_seq; seq_ends=$seq_ends ) @@ -92,7 +92,7 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "baum_welch!" in algos benchs["baum_welch!"] = @benchmarkable begin @@ -107,7 +107,7 @@ function build_benchmarkables( atol=-Inf, loglikelihood_increasing=false, ) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( hmm_guess = build_model($implem, $instance, $params); fb_storage = initialize_forward_backward( hmm_guess, $obs_seq, $control_seq; seq_ends=$seq_ends diff --git a/libs/HMMComparison/src/dynamax.jl b/libs/HMMComparison/src/dynamax.jl index d60e8ef3..4a845911 100644 --- a/libs/HMMComparison/src/dynamax.jl +++ b/libs/HMMComparison/src/dynamax.jl @@ -48,7 +48,7 @@ function HMMBenchmark.build_benchmarkables( filter_vmap = jax.jit(jax.vmap(hmm.filter; in_axes=pylist((pybuiltins.None, 0)))) benchs["forward"] = @benchmarkable begin $(filter_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "viterbi" in algos @@ -57,7 +57,7 @@ function HMMBenchmark.build_benchmarkables( ) benchs["viterbi"] = @benchmarkable begin $(most_likely_states_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward_backward" in algos @@ -66,7 +66,7 @@ function HMMBenchmark.build_benchmarkables( ) benchs["forward_backward"] = @benchmarkable begin $(smoother_vmap)($dyn_params, $obs_tens_jax_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "baum_welch" in algos @@ -78,7 +78,7 @@ function HMMBenchmark.build_benchmarkables( num_iters=$bw_iter, verbose=false, ) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( tup = build_model($implem, $instance, $params); hmm_guess = tup[1]; dyn_params_guess = tup[2]; diff --git a/libs/HMMComparison/src/hmmbase.jl b/libs/HMMComparison/src/hmmbase.jl index 2c260e89..4608071a 100644 --- a/libs/HMMComparison/src/hmmbase.jl +++ b/libs/HMMComparison/src/hmmbase.jl @@ -41,7 +41,7 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.forward($hmm, $obs_mats[k]) end - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "viterbi" in algos @@ -49,7 +49,7 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.viterbi($hmm, $obs_mats[k]) end - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward_backward" in algos @@ -57,13 +57,13 @@ function HMMBenchmark.build_benchmarkables( @threads for k in eachindex($obs_mats) HMMBase.posteriors($hmm, $obs_mats[k]) end - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin HMMBase.fit_mle($hmm, $obs_mat_concat; maxiter=$bw_iter, tol=-Inf) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end return benchs diff --git a/libs/HMMComparison/src/hmmlearn.jl b/libs/HMMComparison/src/hmmlearn.jl index 58b1f053..610c9129 100644 --- a/libs/HMMComparison/src/hmmlearn.jl +++ b/libs/HMMComparison/src/hmmlearn.jl @@ -45,25 +45,25 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.score)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "viterbi" in algos benchs["viterbi"] = @benchmarkable begin $(hmm.decode)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.predict_proba)($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_mat_concat_py, $obs_mat_len_py) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end diff --git a/libs/HMMComparison/src/pomegranate.jl b/libs/HMMComparison/src/pomegranate.jl index adf27f53..d5939919 100644 --- a/libs/HMMComparison/src/pomegranate.jl +++ b/libs/HMMComparison/src/pomegranate.jl @@ -57,19 +57,19 @@ function HMMBenchmark.build_benchmarkables( if "forward" in algos benchs["forward"] = @benchmarkable begin $(hmm.forward)($obs_tens_torch_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "forward_backward" in algos benchs["forward_backward"] = @benchmarkable begin $(hmm.forward_backward)($obs_tens_torch_py) - end evals = 1 samples = 10 + end evals = 1 samples = 100 end if "baum_welch" in algos benchs["baum_welch"] = @benchmarkable begin hmm_guess.fit($obs_tens_torch_py) - end evals = 1 samples = 10 setup = ( + end evals = 1 samples = 100 setup = ( hmm_guess = build_model($implem, $instance, $params) ) end From 083bbf98ed9df6af759f3031b54f0d0a9de2f72c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:45:33 +0100 Subject: [PATCH 7/8] Special case for sparse matrices --- examples/types.jl | 2 +- src/inference/viterbi.jl | 2 +- src/types/hmm.jl | 2 +- src/utils/linalg.jl | 47 ++++++++++++++++++++++++++++++++-------- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/examples/types.jl b/examples/types.jl index d4966a5f..0945be63 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -154,7 +154,7 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St # ## Tests #src -@test_broken nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src +@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src seq_ends = cumsum(rand(rng, 100:200, 100)); #src control_seq = fill(nothing, last(seq_ends)); #src diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index abf254e8..09e18a26 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -62,7 +62,7 @@ function viterbi!( logtrans = log_transition_matrix(hmm, control_seq[t - 1]) ϕₜ, ϕₜ₋₁ = view(ϕ, :, t), view(ϕ, :, t - 1) ψₜ = view(ψ, :, t) - argmaxplus_mul!(ϕₜ, ψₜ, transpose(logtrans), ϕₜ₋₁) + argmaxplus_transmul!(ϕₜ, ψₜ, logtrans, ϕₜ₋₁) ϕₜ .+= logBₜ end diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 476eaeae..a9310798 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -79,7 +79,7 @@ function StatsAPI.fit!( end # Update logs hmm.loginit .= log.(hmm.init) - hmm.logtrans .= log.(hmm.trans) + mynonzeros(hmm.logtrans) .= log.(mynonzeros(hmm.trans)) # Safety check @argcheck valid_hmm(hmm) return nothing diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 07f8f31a..506b29e4 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -7,6 +7,10 @@ mynnz(x::AbstractArray) = length(mynonzeros(x)) elementwise_log(x::AbstractArray) = log.(x) +function elementwise_log(A::SparseMatrixCSC) + return SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, log.(A.nzval)) +end + """ mul_rows_cols!(B, l, A, r) @@ -40,26 +44,51 @@ function mul_rows_cols!( end """ - argmaxplus_mul!(y, ind, A, x) + argmaxplus_transmul!(y, ind, A, x) -Perform the in-place multiplication `A * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each row in `ind`. +Perform the in-place multiplication `transpose(A) * x` _in the sense of max-plus algebra_, store the result in `y`, and store the index of the maximum for each component of `y` in `ind`. """ -function argmaxplus_mul!( +function argmaxplus_transmul!( y::AbstractVector{R}, ind::AbstractVector{<:Integer}, A::AbstractMatrix, x::AbstractVector, ) where {R} - @argcheck axes(A, 1) == eachindex(y) - @argcheck axes(A, 2) == eachindex(x) + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) y .= typemin(R) ind .= 0 for j in axes(A, 2) for i in axes(A, 1) - z = A[i, j] + x[j] - if z > y[i] - y[i] = z - ind[i] = j + z = A[i, j] + x[i] + if z > y[j] + y[j] = z + ind[j] = i + end + end + end + return y +end + +function argmaxplus_transmul!( + y::AbstractVector{R}, + ind::AbstractVector{<:Integer}, + A::SparseMatrixCSC, + x::AbstractVector, +) where {R} + @argcheck axes(A, 1) == eachindex(x) + @argcheck axes(A, 2) == eachindex(y) + Anz = nonzeros(A) + Arv = rowvals(A) + y .= typemin(R) + ind .= 0 + for j in axes(A, 2) + for k in nzrange(A, j) + i = Arv[k] + z = Anz[k] + x[i] + if z > y[j] + y[j] = z + ind[j] = i end end end From 4bf59edebcdeb5048f3e9071d5eca93cce10770a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 24 Feb 2024 20:56:17 +0100 Subject: [PATCH 8/8] Fix docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9459960f..8f0f9d70 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -111,7 +111,7 @@ HiddenMarkovModels.LightCategorical HiddenMarkovModels.log_initialization HiddenMarkovModels.log_transition_matrix HiddenMarkovModels.mul_rows_cols! -HiddenMarkovModels.argmaxplus_mul! +HiddenMarkovModels.argmaxplus_transmul! ``` ## Index